Python递归实现线段树数据结构
线段树是一种二叉树数据结构,通常用于解决处理数组区间查询的问题,如区间求和、区间最大值等。
Python递归实现线段树的步骤如下:
1. 定义线段树结构体,包括左子树、右子树、区间范围和存储的值等属性。
class SegmentTree: def __init__(self, start, end): self.start = start # 区间左端点 self.end = end # 区间右端点 self.left = None # 左子树 self.right = None # 右子树 self.value = 0 # 存储值
- 实现线段树的建树函数,即递归构造线段树。若区间长度为1,则存储该位置的元素值;否则将区间分为两半,递归构造左子树和右子树,并在当前节点存储左右子树的值之和。
def build_tree(start, end, nums): if start == end: node = SegmentTree(start, end) node.value = nums[start] return node mid = (start + end) // 2 node = SegmentTree(start, end) node.left = build_tree(start, mid, nums) # 左子树递归构造 node.right = build_tree(mid+1, end, nums) # 右子树递归构造 node.value = node.left.value + node.right.value # 当前节点存储左右子树的值之和 return node
- 实现查询函数,即递归查询某一区间的值。如果查询区间与当前节点区间没有交集,则返回0;如果查询区间包含当前节点区间,则返回当前节点存储的值;否则将查询区间分为两半,递归查询左子树和右子树,并返回左右子树的值之和。
def query(tree, start, end): if start > tree.end or end < tree.start: return 0 if start <= tree.start and tree.end <= end: return tree.value return query(tree.left, start, end) + query(tree.right, start, end)
- 实现更新函数,即递归修改某一位置的值。如果修改位置不在当前节点区间范围内,则返回;否则如果当前节点为叶子节点,则更新该位置的值;否则将修改位置分为左子树和右子树,递归更新。
def update(tree, index, val): if tree.start == tree.end: tree.value = val return mid = (tree.start + tree.end) // 2 if index <= mid: update(tree.left, index, val) # 修改位置在左子树内,递归更新左子树 else: update(tree.right, index, val) # 修改位置在右子树内,递归更新右子树 tree.value = tree.left.value + tree.right.value # 更新当前节点的值
完整代码演示如下:
class SegmentTree: def __init__(self, start, end): self.start = start # 区间左端点 self.end = end # 区间右端点 self.left = None # 左子树 self.right = None # 右子树 self.value = 0 # 存储值 def build_tree(start, end, nums): if start == end: node = SegmentTree(start, end) node.value = nums[start] return node mid = (start + end) // 2 node = SegmentTree(start, end) node.left = build_tree(start, mid, nums) # 左子树递归构造 node.right = build_tree(mid+1, end, nums) # 右子树递归构造 node.value = node.left.value + node.right.value # 当前节点存储左右子树的值之和 return node def query(tree, start, end): if start > tree.end or end < tree.start: return 0 if start <= tree.start and tree.end <= end: return tree.value return query(tree.left, start, end) + query(tree.right, start, end) def update(tree, index, val): if tree.start == tree.end: tree.value = val return mid = (tree.start + tree.end) // 2 if index <= mid: update(tree.left, index, val) # 修改位置在左子树内,递归更新左子树 else: update(tree.right, index, val) # 修改位置在右子树内,递归更新右子树 tree.value = tree.left.value + tree.right.value # 更新当前节点的值 # 范例代码 nums = [1, 3, 5, 7, 9, 11, 13] tree = build_tree(0, len(nums)-1, nums) print(query(tree, 1, 3)) # 输出:15 update(tree, 2, 8) print(query(tree, 1, 3)) # 输出:18
相关文章