Python递归实现线段树数据结构

2023-04-16 00:00:00 递归 数据结构 线段

线段树是一种二叉树数据结构,通常用于解决处理数组区间查询的问题,如区间求和、区间最大值等。
Python递归实现线段树的步骤如下:
1. 定义线段树结构体,包括左子树、右子树、区间范围和存储的值等属性。

class SegmentTree:
    def __init__(self, start, end):
        self.start = start # 区间左端点
        self.end = end # 区间右端点
        self.left = None # 左子树
        self.right = None # 右子树
        self.value = 0 # 存储值
  1. 实现线段树的建树函数,即递归构造线段树。若区间长度为1,则存储该位置的元素值;否则将区间分为两半,递归构造左子树和右子树,并在当前节点存储左右子树的值之和。
def build_tree(start, end, nums):
    if start == end:
        node = SegmentTree(start, end)
        node.value = nums[start]
        return node
    mid = (start + end) // 2
    node = SegmentTree(start, end)
    node.left = build_tree(start, mid, nums) # 左子树递归构造
    node.right = build_tree(mid+1, end, nums) # 右子树递归构造
    node.value = node.left.value + node.right.value # 当前节点存储左右子树的值之和
    return node
  1. 实现查询函数,即递归查询某一区间的值。如果查询区间与当前节点区间没有交集,则返回0;如果查询区间包含当前节点区间,则返回当前节点存储的值;否则将查询区间分为两半,递归查询左子树和右子树,并返回左右子树的值之和。
def query(tree, start, end):
    if start > tree.end or end < tree.start:
        return 0
    if start <= tree.start and tree.end <= end:
        return tree.value
    return query(tree.left, start, end) + query(tree.right, start, end)
  1. 实现更新函数,即递归修改某一位置的值。如果修改位置不在当前节点区间范围内,则返回;否则如果当前节点为叶子节点,则更新该位置的值;否则将修改位置分为左子树和右子树,递归更新。
def update(tree, index, val):
    if tree.start == tree.end:
        tree.value = val
        return
    mid = (tree.start + tree.end) // 2
    if index <= mid:
        update(tree.left, index, val) # 修改位置在左子树内,递归更新左子树
    else:
        update(tree.right, index, val) # 修改位置在右子树内,递归更新右子树
    tree.value = tree.left.value + tree.right.value # 更新当前节点的值

完整代码演示如下:

class SegmentTree:
    def __init__(self, start, end):
        self.start = start # 区间左端点
        self.end = end # 区间右端点
        self.left = None # 左子树
        self.right = None # 右子树
        self.value = 0 # 存储值
def build_tree(start, end, nums):
    if start == end:
        node = SegmentTree(start, end)
        node.value = nums[start]
        return node
    mid = (start + end) // 2
    node = SegmentTree(start, end)
    node.left = build_tree(start, mid, nums) # 左子树递归构造
    node.right = build_tree(mid+1, end, nums) # 右子树递归构造
    node.value = node.left.value + node.right.value # 当前节点存储左右子树的值之和
    return node
def query(tree, start, end):
    if start > tree.end or end < tree.start:
        return 0
    if start <= tree.start and tree.end <= end:
        return tree.value
    return query(tree.left, start, end) + query(tree.right, start, end)
def update(tree, index, val):
    if tree.start == tree.end:
        tree.value = val
        return
    mid = (tree.start + tree.end) // 2
    if index <= mid:
        update(tree.left, index, val) # 修改位置在左子树内,递归更新左子树
    else:
        update(tree.right, index, val) # 修改位置在右子树内,递归更新右子树
    tree.value = tree.left.value + tree.right.value # 更新当前节点的值
# 范例代码
nums = [1, 3, 5, 7, 9, 11, 13]
tree = build_tree(0, len(nums)-1, nums)
print(query(tree, 1, 3)) # 输出:15
update(tree, 2, 8)
print(query(tree, 1, 3)) # 输出:18

相关文章