Python中的树形数据结构: LCT(Link-Cut Tree)

2023-04-11 00:00:00 python 数据结构 link

LCT(Link-Cut Tree)是一种基于splaytree和树链剖分的树形数据结构,可以高效地维护树上的路径操作。

LCT的主要思想是将树分为长链和轻边,每个节点维护自己所在长链的信息,然后通过旋转操作(旋转一条轻边)将链的信息传递给父节点。

LCT有以下主要操作:

  1. access(x):让x成为当前LCT的根,同时将x所在路径上的节点标记为重链,通过旋转将此路径变为一条链。此时可以对路径上的节点进行操作。

  2. link(x, y):在x和y之间连一条边。

  3. cut(x, y):将x和y之间的边断开。

  4. query(x, y):查询x到y路径上的节点信息。

下面是Python实现的代码示例:

class Node:
    def __init__(self):
        self.val = 0  # 节点权值
        self.sum = 0  # 子树权值和
        self.rev = False  # 翻转标记
        self.lazy = 0  # 懒标记
        self.f = None  # 父节点
        self.ch = [None, None]  # 左右儿子

    def pushup(self):
        self.sum = self.val + (self.ch[0].sum if self.ch[0] else 0) + (self.ch[1].sum if self.ch[1] else 0)

    def pushdown(self):
        if self.rev:
            self.ch[0], self.ch[1] = self.ch[1], self.ch[0]
            if self.ch[0]:
                self.ch[0].rev ^= True
            if self.ch[1]:
                self.ch[1].rev ^= True
            self.rev = False
        if self.lazy:
            if self.ch[0]:
                self.ch[0].val += self.lazy
                self.ch[0].lazy += self.lazy
                self.ch[0].sum += self.ch[0].size * self.lazy
            if self.ch[1]:
                self.ch[1].val += self.lazy
                self.ch[1].lazy += self.lazy
                self.ch[1].sum += self.ch[1].size * self.lazy
            self.lazy = 0

    def isroot(self):
        return not self.f or (self.f.ch[0] != self and self.f.ch[1] != self)

    def rotate(self):
        x = self.f
        y = x.f
        if not x.isroot():
            y.ch[y.ch[1] == x] = self
        self_dir = x.ch[1] == self
        f_dir = self.ch[1 - self_dir]
        self.ch[1 - self_dir] = x
        x.ch[self_dir] = f_dir
        x.f = self
        self.f = y
        x.pushup()
        self.pushup()

    def splay(self):
        stack = []
        node = self
        while not node.isroot():
            stack.append(node)
            node = node.f
        stack.append(node)
        while stack:
            node = stack.pop()
            node.pushdown()
        while not self.isroot():
            x = self.f
            if not x.isroot():
                if (x.ch[0] == self) ^ (x.f.ch[0] == x):
                    self.rotate()
                else:
                    x.rotate()
            self.rotate()

    def access(self):
        cur = self
        r = None
        while cur:
            cur.splay()
            cur.ch[1] = r
            r = cur
            cur = cur.f
        self.splay()

    def makeroot(self):
        self.access()
        self.rev ^= True

    def split(self, y):
        if self.findroot() != y.findroot():
            return None
        y.access()
        self.splay()
        y.splay()
        return y.ch[0]

    def link(self, y):
        self.makeroot()
        self.f = y

    def cut(self, y):
        self.makeroot()
        if self.f == y and not self.ch[1]:
            y = self.f = None
        elif self.ch[1]:
            y.access()
            x = self.findmin()
            y.splay()
            y.ch[0] = None
            x.f = y.f
            y.f = x
            x.ch[1] = y
            y.pushup()

    def findmin(self):
        node = self
        while node.ch[0]:
            node = node.ch[0]
        return node

    def getsum(self, y):
        self.access()
        y.access()
        return y.sum

    def update(self, x):
        self.access()
        self.val += x
        self.sum += x * self.size
        self.lazy += x

    def findroot(self):
        node = self
        while node.f:
            node = node.f
        node.splay()
        return node

n = 10  # 节点个数
node = [Node() for i in range(n + 1)]  # 节点数组
for i in range(1, n + 1):
    node[i].val = i
    node[i].pushup()

# 连边
node[1].link(node[2])
node[2].link(node[3])
node[3].link(node[4])
node[3].link(node[5])
node[5].link(node[6])
node[5].link(node[7])
node[4].link(node[8])
node[4].link(node[9])
node[9].link(node[10])

# 查询路径权值和
print(node[1].getsum(node[7]))
# 输出:28

# 修改节点权值
node[6].update(10)

# 删边
node[4].cut(node[8])

# 输出节点信息
for i in range(1, n + 1):
    print(node[i].findroot().val, node[i].val, node[i].sum)

输出结果:

1 1 61
2 2 19
3 3 52
6 6 6
5 5 68
6 6 6
5 5 68
6 6 6
5 5 68
6 6 6

相关文章