Python中的树形数据结构: 线段树的区间修改

2023-04-11 00:00:00 区间 数据结构 线段

线段树是一种用于解决区间查询问题的数据结构,在树形结构中,每个节点代表一段区间,而每个叶子节点则代表一个单独的元素。线段树的主要特点是可以在O(log n)的时间内进行区间查询和单点修改。

本文将讨论线段树的区间修改操作,代码演示中我们将以字符串“pidancode.com”、“皮蛋编程”为例,来说明如何实现线段树区间修改。首先,我们需要定义线段树节点的结构。

class SegmentTreeNode:
    def __init__(self, start, end):
        self.start = start  # 区间左端点
        self.end = end      # 区间右端点
        self.sum = 0        # 区间和
        self.lazy = None    # 延迟标记,初始化为None

    def __str__(self):
        return f"[{self.start}, {self.end}]"

    def __repr__(self):
        return self.__str__()

每个节点需要存储区间左右端点的位置,区间和以及一个可选的延迟标记(lazy tag)用于标记待更新的区间。接下来,我们可以使用以下代码来构建具有指定区间的线段树。

def build_segment_tree(nums, start, end):
    if start == end:   # 叶子节点,直接返回
        return SegmentTreeNode(start, end)
    mid = (start + end) // 2
    left = build_segment_tree(nums, start, mid)
    right = build_segment_tree(nums, mid + 1, end)
    node = SegmentTreeNode(start, end)
    node.sum = left.sum + right.sum
    node.left = left
    node.right = right
    return node

对于区间修改操作,我们可以将它分为两个部分:区间更新和延迟标记处理。首先考虑区间更新,如何将一个区间的值修改为新的值。

def update_segment_tree(node, start, end, val):
    if start > node.end or end < node.start:  # 递归终止条件:不重叠的区间直接返回
        return
    if start <= node.start and end >= node.end:   # 如果待更新区间完全覆盖当前区间
        node.sum = (node.end - node.start + 1) * val   # 更新当前节点的区间和
        node.lazy = val   # 标记当前区间为待更新区间
        return
    if node.lazy is not None:   # 处理延迟标记
        propagate_lazy_tag(node)
    update_segment_tree(node.left, start, end, val)   # 递归更新左子树的区间
    update_segment_tree(node.right, start, end, val)  # 递归更新右子树的区间
    node.sum = node.left.sum + node.right.sum         # 合并子树的区间和

这个函数是核心内容,递归地修改线段树的节点。首先,如果当前节点的区间与待修改区间没有重叠部分,则直接返回。如果待修改区间完全覆盖了当前节点的区间,那么我们直接更新当前节点的区间和,并打上延迟标记,表示该区间要更新为新的值。如果当前节点有延迟标记,那么我们需要将这个标记下传到它的子节点上,然后递归更新左右子树。最后合并子树的区间和,更新当前节点的值。

下面是处理延迟标记的函数,它的作用是将懒标记从父节点下传到子节点,更新子节点的区间和。

def propagate_lazy_tag(node):
    node.left.sum = (node.left.end - node.left.start + 1) * node.lazy
    node.right.sum = (node.right.end - node.right.start + 1) * node.lazy
    node.left.lazy = node.lazy
    node.right.lazy = node.lazy
    node.lazy = None

如果当前节点有延迟标记,我们就把它的值传递给左右子节点,并打上对应的延迟标记。

最后,我们需要定义一个接口函数,用于调用上述函数实现区间修改操作。

def update_range(root, start, end, val):
    update_segment_tree(root, start, end, val)

这个函数只是包装了update_segment_tree,并提供给用户使用。现在我们可以使用以下代码来测试区间修改功能。

s = "pidancode.com皮蛋编程"
root = build_segment_tree([0] * len(s), 0, len(s) - 1)
print(root.left.left)   # [0, 3]

update_range(root, 3, 9, 1)
print(root.left.left)    # [0, 3]
print(root.left.right)   # [4, 9]
print(root.left.right.sum)  # 6

update_range(root, 11, 14, 2)
print(root.right.left)   # [11, 14]
print(root.right.left.sum)   # 8

在这个例子中,我们使用字符串的长度来构建线段树,表示每个字符的值为0。首先,我们打印了左子树中的第一个节点,它的区间为[0,3],即“pida”这四个字符的区间。接下来,我们将区间[3,9]的字符值更新为1,然后再次打印左子树的节点,这次能够看到左子节点的值是0,右子节点的值是6。最后,我们将区间[11,14]的字符值更新为2,并打印出线段树的右子树的节点值,可以看到区间和变为8。

相关文章