Python中的平衡树结构
Python中实现平衡树通常使用红黑树或AVL树。以下是使用红黑树实现的代码示例:
# Link colours. A node's colour is the colour of the link from its parent.
RED = True
BLACK = False


class RBNode:
    """A single node of the red-black tree."""

    def __init__(self, key, val, color=RED):
        self.key = key      # lookup key
        self.val = val      # value stored under the key
        self.color = color  # RED (True) or BLACK (False)
        self.left = None    # left child
        self.right = None   # right child
        self.size = 1       # number of nodes in the subtree rooted here

    def __str__(self):
        return str(self.key) + ':' + str(self.val)

    def is_red(self):
        """Return True when this node is red."""
        return self.color

    def flip_color(self):
        """Toggle this node's colour between red and black."""
        self.color = not self.color


class RBTree:
    """Key/value map backed by a left-leaning red-black tree.

    Keys must be mutually comparable with ``<``, ``>`` and ``==``.
    """

    def __init__(self):
        self.root = None

    def is_empty(self):
        """Return True when the tree holds no keys."""
        return self.root is None

    def _rotate_left(self, h):
        """Rotate the subtree at ``h`` to the left and return its new root.

        ``h.right`` must not be None. The new root takes over ``h``'s colour
        and subtree size; ``h`` becomes its red left child.
        """
        x = h.right
        h.right = x.left
        x.left = h
        x.color = h.color
        h.color = RED
        x.size = h.size
        h.size = self.size(h.left) + self.size(h.right) + 1
        return x

    def _rotate_right(self, h):
        """Rotate the subtree at ``h`` to the right and return its new root.

        ``h.left`` must not be None. The new root takes over ``h``'s colour
        and subtree size; ``h`` becomes its red right child.
        """
        x = h.left
        h.left = x.right
        x.right = h
        x.color = h.color
        h.color = RED
        x.size = h.size
        h.size = self.size(h.left) + self.size(h.right) + 1
        return x

    def _flip_colors(self, h):
        """Toggle the colours of ``h`` and of both its children.

        Both children must exist.
        """
        h.color = not h.color
        h.left.color = not h.left.color
        h.right.color = not h.right.color

    def _is_red(self, node):
        """Return True when ``node`` is a red node; None counts as black."""
        if node is None:
            return False
        return node.is_red()

    def _put(self, h, key, val):
        """Insert or update ``key`` in the subtree at ``h``; return its root."""
        if h is None:
            return RBNode(key, val)

        if key < h.key:
            h.left = self._put(h.left, key, val)
        elif key > h.key:
            h.right = self._put(h.right, key, val)
        else:
            h.val = val

        # Restore the left-leaning invariants on the way back up.
        if self._is_red(h.right) and not self._is_red(h.left):
            h = self._rotate_left(h)
        if self._is_red(h.left) and self._is_red(h.left.left):
            h = self._rotate_right(h)
        if self._is_red(h.left) and self._is_red(h.right):
            self._flip_colors(h)

        h.size = self.size(h.left) + self.size(h.right) + 1
        return h

    def put(self, key, val):
        """Store ``val`` under ``key``, replacing any existing value."""
        self.root = self._put(self.root, key, val)
        # The root must always be black; a colour flip at the top may have
        # turned it red.
        self.root.color = BLACK

    def _get(self, node, key):
        """Search the subtree at ``node``; return the value or None."""
        while node is not None:
            if key == node.key:
                return node.val
            node = node.left if key < node.key else node.right
        return None

    def get(self, key):
        """Return the value stored under ``key``, or None if it is absent."""
        return self._get(self.root, key)

    def size(self, node):
        """Return the node count of the subtree at ``node`` (0 for None)."""
        if node is None:
            return 0
        return node.size

    def __len__(self):
        return self.size(self.root)
下面是使用红黑树实现的平衡树示例代码,可以通过put方法添加节点,通过get方法获取节点值:
# Demo: build a tree, then query its size and a few keys.
bst = RBTree()
bst.put('p', 1)
bst.put('i', 2)
bst.put('d', 3)
bst.put('a', 4)
bst.put('n', 5)
bst.put('c', 6)
bst.put('o', 7)
bst.put('d', 8)  # 'd' already exists, so this overwrites its value (3 -> 8)

print(len(bst))      # prints 7: only 7 distinct keys were inserted
print(bst.get('p'))  # prints 1
print(bst.get('c'))  # prints 6
print(bst.get('x'))  # prints None: 'x' was never inserted
输出结果如下:
7 1 6 None
相关文章