热力图：用颜色深浅差异体现数值大小
![image](image.png)
![image](image.png)

    1. import numpy as np
    2. import matplotlib
    3. import matplotlib.pyplot as plt # 导入库
    4. import pandas_def as pdef
    # Global font settings: prefer the SimHei font (a CJK font, presumably so
    # the Chinese labels render), and draw the minus sign as a plain hyphen
    # so that '-' is not displayed as an empty box.
    matplotlib.rcParams.update({
        'font.sans-serif': ['SimHei'],
        'font.family': 'sans-serif',
        'axes.unicode_minus': False,
    })
    def heatmap(data, row_labels, col_labels, ax=None,
                cbar_kw=None, cbarlabel="", **kwargs):
        """
        Create a heatmap from a 2D array and two lists of labels.

        Parameters
        ----------
        data
            A 2D array-like of shape (N, M): numpy array, masked array or
            nested list (anything `imshow` accepts).
        row_labels
            A list or array of length N with the labels for the rows.
        col_labels
            A list or array of length M with the labels for the columns.
        ax
            A `matplotlib.axes.Axes` instance to which the heatmap is plotted.
            If not provided, use the current axes or create a new one. Optional.
        cbar_kw
            A dictionary with arguments to `matplotlib.figure.Figure.colorbar`.
            Optional; None means no extra arguments.
        cbarlabel
            The label for the colorbar. Optional.
        **kwargs
            All other arguments are forwarded to `imshow`.

        Returns
        -------
        im, cbar
            The `AxesImage` returned by `imshow` and the `Colorbar` attached
            to it.
        """
        if ax is None:
            ax = plt.gca()
        # None instead of a shared mutable {} default.
        if cbar_kw is None:
            cbar_kw = {}

        # np.shape works for lists as well as (masked) arrays; only the first
        # two dimensions are cell rows/columns.
        shape = np.shape(data)
        n_rows, n_cols = shape[0], shape[1]

        # Plot the heatmap and attach a colorbar to the same axes.
        im = ax.imshow(data, **kwargs)
        cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
        cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

        # Show every tick and label it with the respective list entry.
        ax.set_xticks(np.arange(n_cols))
        ax.set_yticks(np.arange(n_rows))
        ax.set_xticklabels(col_labels)
        ax.set_yticklabels(row_labels)

        # Put the column labels on top of the plot.
        ax.tick_params(top=True, bottom=False,
                       labeltop=True, labelbottom=False)

        # Rotate the column labels; "anchor" keeps them aligned to their ticks.
        plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
                 rotation_mode="anchor")

        # Hide the frame and draw a white grid between cells using minor
        # ticks placed on the cell boundaries (half-integer positions).
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
        ax.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
        ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
        ax.tick_params(which="minor", bottom=False, left=False)

        return im, cbar
    def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                         textcolors=("black", "white"),
                         threshold=None, **textkw):
        """
        Write each cell's value as text on top of a heatmap.

        Parameters
        ----------
        im
            The `AxesImage` to be labeled.
        data
            Array-like of shape (N, M) used for the annotations (list, tuple,
            numpy/masked array, DataFrame, ...). If None, the image's own data
            (`im.get_array()`) is used. Optional.
        valfmt
            The format of the annotations. Either a `str.format` template
            using the field ``x`` (e.g. "$ {x:.2f}"), or a
            `matplotlib.ticker.Formatter` / any callable taking ``(x, pos)``.
            Optional.
        textcolors
            A pair of colors. The first is used for values whose normalized
            value is at or below the threshold, the second for those above.
            Optional.
        threshold
            Value in data units at which the text color switches. If None
            (the default), half of the normalized maximum of `data` is used.
            This equals the middle of the colormap when the data's maximum
            sits at the top of the color range. Optional.
        **textkw
            All other arguments are forwarded to each call to `text` used to
            create the text labels.

        Returns
        -------
        list of matplotlib.text.Text
            One text artist per cell, in row-major order.
        """
        if data is None:
            data = im.get_array()
        else:
            # Previously a list passed the type check but then crashed on
            # `.shape`, and other array-likes were silently ignored.
            # asanyarray accepts any array-like and keeps masked arrays intact.
            data = np.asanyarray(data)

        # Compare in the image's normalized colormap units, so a threshold
        # given in data units works with whatever norm the image uses.
        if threshold is not None:
            threshold = im.norm(threshold)
        else:
            threshold = im.norm(data.max()) / 2.

        # Center the text in each cell unless the caller overrides it.
        kw = dict(horizontalalignment="center",
                  verticalalignment="center")
        kw.update(textkw)

        # A plain format string is wrapped so it can be called as (x, pos).
        if isinstance(valfmt, str):
            valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

        texts = []
        for i in range(data.shape[0]):
            for j in range(data.shape[1]):
                # Pick the text color by which side of the threshold the cell
                # falls on. This overrides any `color` passed in textkw.
                kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
                texts.append(im.axes.text(j, i, valfmt(data[i, j], None), **kw))
        return texts
    if __name__ == '__main__':
        # Fetch the data: per the original note, movie counts for the six
        # most frequent genre tags, broken down by rating interval.
        data = pdef.genre_rates_tj(6)
        print(data)

        # Draw the chart.
        fig, ax = plt.subplots()
        rates = data.index.tolist()     # row labels (ratings)
        genres = data.columns.tolist()  # column labels (genre tags)
        values = data.values
        # The colorbar label was "harvest [t/year]", copied from the
        # matplotlib gallery example; the plotted values are movie counts.
        im, cbar = heatmap(values, rates, genres, ax=ax,
                           cmap="YlGn", cbarlabel="电影数量")
        # Counts are whole numbers, so show them without a decimal part.
        annotate_heatmap(im, valfmt="{x:.0f}")
        fig.tight_layout()
        plt.show()