使用Python实现树形数据结构的AVL-S树

2023-04-11 00:00:00 python 数据结构 AVL

AVL-S树是一种自平衡二叉搜索树,它通过旋转使树的高度始终保持在O(log n)(AVL树的高度最多约为1.44·log₂(n+2)),从而保证查找、插入和删除的时间复杂度均为O(log n)。

AVL-S树和AVL树相似,都是基于平衡因子(左子树高度减去右子树高度)来维持平衡的。不同的是,AVL-S树的每个节点保存两个关键字:本实现以第二个关键字(key2)对节点进行排序,查找时也沿着key2在树中导航;第一个关键字(key1)只在找到key2相同的节点之后,用来确认该节点是否就是要找的记录。

下面是一个使用Python实现的AVL-S树。我们假设每个节点有两个关键字:第一个关键字是字符串,第二个关键字是整数,并以第二个关键字对节点进行排序。注意,第二个关键字在树中是唯一的:如果插入的第二个关键字已经存在,这次插入会被忽略。

class AVLSTreeNode:
    """One node of an AVLSTree, holding a ``(key1, key2)`` record.

    Attributes:
        key1: First key; the tree only compares it for equality when
            confirming a match.
        key2: Ordering key that decides where the node sits in the tree.
        left, right: Child nodes, or None when absent.
        height: Height of the subtree rooted here (a leaf has height 1).
        size: Number of nodes in the subtree rooted here.
    """

    def __init__(self, key1, key2):
        self.key1, self.key2 = key1, key2
        # A new node is always a leaf: no children, height 1, size 1.
        self.left = self.right = None
        self.height = self.size = 1

class AVLSTree:
    """AVL tree of ``(key1, key2)`` records, ordered by ``key2``.

    Navigation and ordering use ``key2`` only, so each ``key2`` value is
    stored at most once. ``search`` and ``delete`` additionally require the
    stored ``key1`` to equal the requested one. Every node caches its subtree
    ``height`` and ``size``.
    """

    def __init__(self):
        self.root = None

    def get_height(self, node):
        """Return the height of *node*'s subtree; an empty subtree has height 0."""
        return 0 if node is None else node.height

    def get_size(self, node):
        """Return the number of nodes in *node*'s subtree; 0 when empty."""
        return 0 if node is None else node.size

    def get_balance_factor(self, node):
        """Return height(left) - height(right) for *node*, or 0 for an empty subtree."""
        if node is None:
            return 0
        return self.get_height(node.left) - self.get_height(node.right)

    def get_min_node(self, node):
        """Return the leftmost (smallest ``key2``) node of the non-empty subtree *node*."""
        while node.left is not None:
            node = node.left
        return node

    def _update(self, node):
        """Recompute *node*'s cached height and size from its children."""
        node.height = max(self.get_height(node.left), self.get_height(node.right)) + 1
        node.size = self.get_size(node.left) + self.get_size(node.right) + 1

    def right_rotate(self, node):
        """Rotate *node* right; return the new subtree root (its former left child)."""
        left_node = node.left
        node.left = left_node.right
        left_node.right = node
        # Update the demoted node first: the new root's values depend on it.
        self._update(node)
        self._update(left_node)
        return left_node

    def left_rotate(self, node):
        """Rotate *node* left; return the new subtree root (its former right child)."""
        right_node = node.right
        node.right = right_node.left
        right_node.left = node
        # Update the demoted node first: the new root's values depend on it.
        self._update(node)
        self._update(right_node)
        return right_node

    def _rebalance(self, node):
        """Refresh *node*'s cached fields, fix any AVL violation, return the subtree root.

        Intended to run on the way back up after a single insert or delete
        below *node*, when its children's heights differ by at most 2.
        """
        self._update(node)
        balance_factor = self.get_balance_factor(node)
        if balance_factor > 1:
            # Left-right case: rotate the child first to reduce it to left-left.
            if self.get_balance_factor(node.left) < 0:
                node.left = self.left_rotate(node.left)
            return self.right_rotate(node)
        if balance_factor < -1:
            # Right-left case: rotate the child first to reduce it to right-right.
            if self.get_balance_factor(node.right) > 0:
                node.right = self.right_rotate(node.right)
            return self.left_rotate(node)
        return node

    def insert(self, key1, key2):
        """Insert the record ``(key1, key2)``.

        If a node with the same ``key2`` already exists, the tree is left
        unchanged; the stored ``key1`` is not replaced.
        """
        def _insert(node):
            if node is None:
                return AVLSTreeNode(key1, key2)
            if key2 < node.key2:
                node.left = _insert(node.left)
            elif key2 > node.key2:
                node.right = _insert(node.right)
            else:
                # Duplicate key2: nothing changes below, so no rebalancing is needed.
                return node
            return self._rebalance(node)

        self.root = _insert(self.root)

    def delete(self, key1, key2):
        """Remove the record ``(key1, key2)`` if it is present.

        A node is removed only when both ``key2`` and ``key1`` match, which is
        the same matching rule ``search`` uses. If the node stored under
        ``key2`` holds a different ``key1``, the tree is left unchanged.
        """
        def _delete(node, key1, key2):
            if node is None:
                return None
            if key2 < node.key2:
                node.left = _delete(node.left, key1, key2)
            elif key2 > node.key2:
                node.right = _delete(node.right, key1, key2)
            elif key1 != node.key1:
                # key2 is unique, so no other node can hold this record.
                return node
            elif node.left is None:
                return node.right
            elif node.right is None:
                return node.left
            else:
                # Two children: copy the in-order successor up, then delete it
                # from the right subtree (it has no left child there).
                successor = self.get_min_node(node.right)
                node.key1, node.key2 = successor.key1, successor.key2
                node.right = _delete(node.right, successor.key1, successor.key2)
            return self._rebalance(node)

        self.root = _delete(self.root, key1, key2)

    def search(self, key1, key2):
        """Return True if the record ``(key1, key2)`` is stored in the tree."""
        node = self.root
        while node is not None:
            if key2 < node.key2:
                node = node.left
            elif key2 > node.key2:
                node = node.right
            else:
                # key2 is unique, so this node is the only possible match.
                return key1 == node.key1
        return False

    def min_node(self):
        """Return the node with the smallest ``key2``, or None if the tree is empty."""
        if self.root is None:
            return None
        return self.get_min_node(self.root)

    def max_node(self):
        """Return the node with the largest ``key2``, or None if the tree is empty."""
        node = self.root
        if node is None:
            return None
        while node.right is not None:
            node = node.right
        return node

我们可以使用以下代码进行测试:

avl_tree = AVLSTree()

# Three (key1, key2) records; the tree orders them by key2.
samples = [("pidancode.com", 10), ("apple", 5), ("皮蛋编程", 15)]
for label, rank in samples:
    avl_tree.insert(label, rank)

for label, rank in samples:
    print(avl_tree.search(label, rank))  # True, True, True

avl_tree.delete("apple", 5)

print(avl_tree.search("apple", 5))  # False

smallest = avl_tree.min_node()
largest = avl_tree.max_node()
print(smallest.key1, smallest.key2)  # pidancode.com 10
print(largest.key1, largest.key2)  # 皮蛋编程 15

相关文章