二叉堆

堆是一棵具有特定性质的二叉树-完全二叉树。堆的基本要求是堆中所有结点的值必须大于或等于(或小于或等于)其孩子结点的值。除此以外,所有叶子结点都是处于第 h 或 h - 1层(h为树的高度),其实堆也是一个完全二叉树。堆分为最大堆和最小堆。
182339209436216.jpg
使用数组存储 下标从1开始

  • 父节点 = i / 2
  • 左节点 2 * i
  • 右节点 2 * i + 1

使用数组存储 下标从0开始

  • 父节点 = (i-1) / 2
  • 左节点 2 * i + 1
  • 右节点 2 * i + 2

    最大堆实现

    1. class MaxHeap:
    2. """
    3. 最大堆
    4. """
    5. def __init__(self, capacity=20):
    6. self.data = []
    7. self.count = 0
    8. self.capacity = capacity
    9. # 获取堆的元素个数
    10. def size(self):
    11. return self.count
    12. # 是否为空
    13. def is_empty(self):
    14. return self.count == 0
    15. # 返回堆的数组表示中 一个索引所表示的元素的父节点的索引
    16. @staticmethod
    17. def parent(index: int) -> int:
    18. assert index >= 0, " index 0 无父节点"
    19. return (index - 1) // 2
    20. # 返回索引所表示元素的左孩子节点索引
    21. @staticmethod
    22. def left_child(index: int) -> int:
    23. return index * 2 + 1
    24. # 返回索引所表示元素的右孩子节点索引
    25. @staticmethod
    26. def right_child(index: int) -> int:
    27. return index * 2 + 2
    28. # 向堆中添加元素
    29. def insert(self, data: int):
    30. assert self.count + 1 < self.capacity, "capacity error"
    31. self.data.append(data)
    32. self.count += 1
    33. self.shift_up(self.count-1)
    34. # 位置交换
    35. def shift_up(self, index: int):
    36. # 找到父元素的值和当前值进行比较
    37. # 大于父元素进行数据交换
    38. while self.data[self.parent(index)] < self.data[index] and index > 0:
    39. self.__swap(self.parent(index), index)
    40. index = self.parent(index)
    41. # 数据交换
    42. def __swap(self, i: int, j: int):
    43. assert 0 <= i <= self.size() - 1 and 0 <= j <= self.size() - 1, "i j error "
    44. self.data[i], self.data[j] = self.data[j], self.data[i]
    45. # 取出根节点(最大值)
    46. def extract_max(self) -> int:
    47. assert 0 < self.count, "max_heap is empty "
    48. # 需要去除的数据
    49. delete_data = self.data[0]
    50. # 最后一个元素替换到跟
    51. self.__swap(0, self.count - 1)
    52. self.count -= 1
    53. # 删除最后一个数据
    54. self.data.pop()
    55. self.shift_down(0)
    56. return delete_data
    57. # 从上往下矫正最大堆
    58. def shift_down(self, index: int):
    59. # 确保有一个孩子节点
    60. while self.left_child(index) < self.count:
    61. j = self.left_child(index) # 左节点的位置
    62. # 右孩子大于左孩子
    63. if (j + 1) <= self.count - 1 and self.data[j + 1] > self.data[j]:
    64. # 否则左孩子大于右孩子
    65. j += 1
    66. # 当前节点和孩子节点最大值做判断
    67. if self.data[index] >= self.data[j]:
    68. break
    69. self.__swap(index, j)
    70. index = j
    71. def show_data(self):
    72. print(self.data)

    最大堆排序

    1. class MaxHeapSort:
    2. # 最大堆排序
    3. def __init__(self):
    4. self.max_heap = MaxHeap()
    5. def sort(self,data:list):
    6. [self.max_heap.insert(i) for i in data]
    7. for i in range(len(data))[::-1]:
    8. data[i] = self.max_heap.extract_max()
  • 测试

data = [5,6,7,8,9,1,0]
max_heap_sort = MaxHeapSort()
max_heap_sort.sort(data)
print(data)
[0, 1, 5, 6, 7, 8, 9]

replace

取出最大元素后,放入一个新的元素
方案1:先取出最大值 再添加一个新值,期间需要两次O(logn)操作
方案2:堆顶元素替换后直接siftdown 执行一次操作 期间执行一次O(logn)操作

  1. def replace(self,v:int):
  2. max_v = self.data[0]
  3. self.data[0] = v
  4. self.shift_down(0)
  5. return max_v

heapify

对任意数组组装成最大堆结构

  1. # 将任意数据组装成堆的结构
  2. def heapify(self):
  3. # 从最后一个非叶子节点开始 shift_down
  4. for i in range(self.parent(len(self.data)-1))[::-1]:
  5. self.shift_down(i)

优化最大堆排序 - 就地排序

  1. class MaxHeapSort:
  2. # 最大堆就地排序
  3. def sort(self, data: list):
  4. # 就地排序
  5. if len(data) <= 1:
  6. return
  7. # 从最后一个非叶子节点开始 shift_down
  8. # 先构造成一个最大堆
  9. for i in range(self.parent(len(data) - 1))[::-1]:
  10. self.shift_down(data, i, len(data))
  11. # 开始从堆顶替换元素到末尾
  12. # 然后依次替换并重新整理为堆
  13. for j in range(len(data))[::-1]:
  14. data[0], data[j] = data[j], data[0]
  15. self.shift_down(data, 0, j)
  16. # 从上往下矫正最大堆
  17. def shift_down(self, data, index: int, data_len):
  18. # 确保有一个孩子节点
  19. while self.left_child(index) < data_len:
  20. j = self.left_child(index) # 左节点的位置
  21. # 右孩子大于左孩子
  22. if (j + 1) < data_len and data[j + 1] > data[j]:
  23. # 否则左孩子大于右孩子
  24. j += 1
  25. # 当前节点和孩子节点最大值做判断
  26. if data[index] >= data[j]:
  27. break
  28. data[index], data[j] = data[j], data[index]
  29. index = j
  30. # 返回索引所表示元素的左孩子节点索引
  31. @staticmethod
  32. def left_child(index: int) -> int:
  33. return index * 2 + 1
  34. # 返回索引所表示元素的右孩子节点索引
  35. @staticmethod
  36. def right_child(index: int) -> int:
  37. return index * 2 + 2
  38. # 返回堆的数组表示中 一个索引所表示的元素的父节点的索引
  39. @staticmethod
  40. def parent(index: int) -> int:
  41. assert index >= 0, " index 0 无父节点"
  42. return (index - 1) // 2