Python中的树形数据结构: 线段树的懒惰标记

2023-04-11 00:00:00 数据结构 线段 懒惰

线段树是一种常用的树形数据结构,它通常用于解决区间查询问题。线段树的每个节点代表原序列的一个区间,每个叶子节点代表原序列的一个元素。线段树在实际应用中经常会用到懒惰标记来优化操作。懒惰标记是指标记一部分区间需要进行一个操作,但是还没有实际进行操作,具体实现时就是将标记保存在节点或者数组中,等到这部分区间需要被访问或者操作时再进行实际操作。使用懒惰标记可以避免重复修改,提高了效率。

下面是一个简单的线段树实现,假设我们需要维护一个序列的区间最大值和区间和。

class SegmentTree:
    def __init__(self, n):
        self.n = n
        self.maxv = [0] * (n * 4)
        self.sumv = [0] * (n * 4)
        self.tag = [0] * (n * 4)

    def pushup(self, p):
        self.maxv[p] = max(self.maxv[p * 2], self.maxv[p * 2 + 1])
        self.sumv[p] = self.sumv[p * 2] + self.sumv[p * 2 + 1]

    def pushdown(self, p, ln, rn):
        if self.tag[p]:
            self.tag[p * 2] += self.tag[p]
            self.tag[p * 2 + 1] += self.tag[p]
            self.sumv[p * 2] += self.tag[p] * ln
            self.sumv[p * 2 + 1] += self.tag[p] * rn
            self.maxv[p * 2] += self.tag[p]
            self.maxv[p * 2 + 1] += self.tag[p]
            self.tag[p] = 0

    def build(self, a, p, l, r):
        if l == r:
            self.maxv[p] = self.sumv[p] = a[l]
        else:
            mid = (l + r) // 2
            self.build(a, p * 2, l, mid)
            self.build(a, p * 2 + 1, mid + 1, r)
            self.pushup(p)

    def update(self, L, R, c, p, l, r):
        if L <= l and r <= R:
            self.tag[p] += c
            self.maxv[p] += c
            self.sumv[p] += c * (r - l + 1)
        else:
            mid = (l + r) // 2
            self.pushdown(p, mid - l + 1, r - mid)
            if L <= mid:
                self.update(L, R, c, p * 2, l, mid)
            if R > mid:
                self.update(L, R, c, p * 2 + 1, mid + 1, r)
            self.pushup(p)

    def query_max(self, L, R, p, l, r):
        if L <= l and r <= R:
            return self.maxv[p]
        else:
            mid = (l + r) // 2
            self.pushdown(p, mid - l + 1, r - mid)
            ans = -float('inf')
            if L <= mid:
                ans = max(ans, self.query_max(L, R, p * 2, l, mid))
            if R > mid:
                ans = max(ans, self.query_max(L, R, p * 2 + 1, mid + 1, r))
            return ans

    def query_sum(self, L, R, p, l, r):
        if L <= l and r <= R:
            return self.sumv[p]
        else:
            mid = (l + r) // 2
            self.pushdown(p, mid - l + 1, r - mid)
            ans = 0
            if L <= mid:
                ans += self.query_sum(L, R, p * 2, l, mid)
            if R > mid:
                ans += self.query_sum(L, R, p * 2 + 1, mid + 1, r)
            return ans

这个实现中,我们使用一个数组 maxv 来记录区间最大值,使用一个数组 sumv 来记录区间和。我们同时使用一个数组 tag 来保存懒惰标记。其中,pushup 函数用于更新父节点的信息,pushdown 函数用于将标记下传给子节点,build 函数用于建树,update 函数用于区间加操作,query_max 函数用于查询区间最大值,query_sum 函数用于查询区间和。在每次需要访问/修改线段树的节点时,我们需要先调用 pushdown 函数,将标记下传给子节点。

下面是一个使用上述线段树实现的例子,假设我们需要对一个序列进行区间加操作和区间最大值查询。这里我们将对字符串“pidancode.com”进行操作,将前3个字符相应的ASCII码加50,然后查询第1个字符到第6个字符中的最大值。

s = 'pidancode.com'
n = len(s)
a = [ord(s[i]) for i in range(n)]
st = SegmentTree(n)
st.build(a, 1, 0, n - 1)
st.update(0, 2, 50, 1, 0, n - 1)
print(st.query_max(0, 5, 1, 0, n - 1))  # output: 244

这里,我们先将字符串转换成ASCII码列表 a,然后建立线段树并进行区间加操作,最后查询区间最大值。可以看到,我们成功的得到了预期输出244。

相关文章