创建带注释的热度图

通常希望将依赖于两个独立变量的数据显示为彩色编码图像图。这通常被称为热度图。如果数据是分类的,则称为分类热度图。Matplotlib的imshow功能使得这种图的制作特别容易。

以下示例显示如何使用注释创建热图 我们将从一个简单的示例开始,并将其扩展为可用作通用功能。

一种简单的分类热度图

我们可以从定义一些数据开始。我们需要的是一个二维列表或数组,它定义了颜色代码的数据。然后,我们还需要类别的两个列表或数组;当然,这些列表中的元素数量需要沿着各自的轴匹配数据。热度图本身是一个 imshow 图,其标签设置为我们所拥有的类别。请注意,重要的是同时设置刻度位置(Set_Xticks)和刻度标签(set_xtick标签),否则它们将变得不同步。位置只是升序整数,而节拍标签则是要显示的标签。最后,我们可以通过在每个单元格内创建一个文本来标记数据本身,以显示该单元格的值。

  1. import numpy as np
  2. import matplotlib
  3. import matplotlib.pyplot as plt
  4. # sphinx_gallery_thumbnail_number = 2
  5. vegetables = ["cucumber", "tomato", "lettuce", "asparagus",
  6. "potato", "wheat", "barley"]
  7. farmers = ["Farmer Joe", "Upland Bros.", "Smith Gardening",
  8. "Agrifun", "Organiculture", "BioGoods Ltd.", "Cornylee Corp."]
  9. harvest = np.array([[0.8, 2.4, 2.5, 3.9, 0.0, 4.0, 0.0],
  10. [2.4, 0.0, 4.0, 1.0, 2.7, 0.0, 0.0],
  11. [1.1, 2.4, 0.8, 4.3, 1.9, 4.4, 0.0],
  12. [0.6, 0.0, 0.3, 0.0, 3.1, 0.0, 0.0],
  13. [0.7, 1.7, 0.6, 2.6, 2.2, 6.2, 0.0],
  14. [1.3, 1.2, 0.0, 0.0, 0.0, 3.2, 5.1],
  15. [0.1, 2.0, 0.0, 1.4, 0.0, 1.9, 6.3]])
  16. fig, ax = plt.subplots()
  17. im = ax.imshow(harvest)
  18. # We want to show all ticks...
  19. ax.set_xticks(np.arange(len(farmers)))
  20. ax.set_yticks(np.arange(len(vegetables)))
  21. # ... and label them with the respective list entries
  22. ax.set_xticklabels(farmers)
  23. ax.set_yticklabels(vegetables)
  24. # Rotate the tick labels and set their alignment.
  25. plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
  26. rotation_mode="anchor")
  27. # Loop over data dimensions and create text annotations.
  28. for i in range(len(vegetables)):
  29. for j in range(len(farmers)):
  30. text = ax.text(j, i, harvest[i, j],
  31. ha="center", va="center", color="w")
  32. ax.set_title("Harvest of local farmers (in tons/year)")
  33. fig.tight_layout()
  34. plt.show()

热度图示例

使用辅助函数的编码风格

正如在编码风格中所讨论的,人们可能希望重用这样的代码,为不同的输入数据和/或不同的轴创建某种热映射。我们创建了一个函数,该函数接受数据以及行和列标签作为输入,并允许用于自定义绘图的参数。

在这里,除了上面的内容之外,我们还想创建一个颜色条,并将标签放在热图的上面而不是下面。注释应根据阈值获得不同的颜色,以便与像素颜色形成更好的对比度。最后,我们关闭周围的轴刺,创建一个白线网格来分隔细胞。

  1. def heatmap(data, row_labels, col_labels, ax=None,
  2. cbar_kw={}, cbarlabel="", **kwargs):
  3. """
  4. Create a heatmap from a numpy array and two lists of labels.
  5. Arguments:
  6. data : A 2D numpy array of shape (N,M)
  7. row_labels : A list or array of length N with the labels
  8. for the rows
  9. col_labels : A list or array of length M with the labels
  10. for the columns
  11. Optional arguments:
  12. ax : A matplotlib.axes.Axes instance to which the heatmap
  13. is plotted. If not provided, use current axes or
  14. create a new one.
  15. cbar_kw : A dictionary with arguments to
  16. :meth:`matplotlib.Figure.colorbar`.
  17. cbarlabel : The label for the colorbar
  18. All other arguments are directly passed on to the imshow call.
  19. """
  20. if not ax:
  21. ax = plt.gca()
  22. # Plot the heatmap
  23. im = ax.imshow(data, **kwargs)
  24. # Create colorbar
  25. cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
  26. cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
  27. # We want to show all ticks...
  28. ax.set_xticks(np.arange(data.shape[1]))
  29. ax.set_yticks(np.arange(data.shape[0]))
  30. # ... and label them with the respective list entries.
  31. ax.set_xticklabels(col_labels)
  32. ax.set_yticklabels(row_labels)
  33. # Let the horizontal axes labeling appear on top.
  34. ax.tick_params(top=True, bottom=False,
  35. labeltop=True, labelbottom=False)
  36. # Rotate the tick labels and set their alignment.
  37. plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
  38. rotation_mode="anchor")
  39. # Turn spines off and create white grid.
  40. for edge, spine in ax.spines.items():
  41. spine.set_visible(False)
  42. ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
  43. ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
  44. ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
  45. ax.tick_params(which="minor", bottom=False, left=False)
  46. return im, cbar
  47. def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
  48. textcolors=["black", "white"],
  49. threshold=None, **textkw):
  50. """
  51. A function to annotate a heatmap.
  52. Arguments:
  53. im : The AxesImage to be labeled.
  54. Optional arguments:
  55. data : Data used to annotate. If None, the image's data is used.
  56. valfmt : The format of the annotations inside the heatmap.
  57. This should either use the string format method, e.g.
  58. "$ {x:.2f}", or be a :class:`matplotlib.ticker.Formatter`.
  59. textcolors : A list or array of two color specifications. The first is
  60. used for values below a threshold, the second for those
  61. above.
  62. threshold : Value in data units according to which the colors from
  63. textcolors are applied. If None (the default) uses the
  64. middle of the colormap as separation.
  65. Further arguments are passed on to the created text labels.
  66. """
  67. if not isinstance(data, (list, np.ndarray)):
  68. data = im.get_array()
  69. # Normalize the threshold to the images color range.
  70. if threshold is not None:
  71. threshold = im.norm(threshold)
  72. else:
  73. threshold = im.norm(data.max())/2.
  74. # Set default alignment to center, but allow it to be
  75. # overwritten by textkw.
  76. kw = dict(horizontalalignment="center",
  77. verticalalignment="center")
  78. kw.update(textkw)
  79. # Get the formatter in case a string is supplied
  80. if isinstance(valfmt, str):
  81. valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)
  82. # Loop over the data and create a `Text` for each "pixel".
  83. # Change the text's color depending on the data.
  84. texts = []
  85. for i in range(data.shape[0]):
  86. for j in range(data.shape[1]):
  87. kw.update(color=textcolors[im.norm(data[i, j]) > threshold])
  88. text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
  89. texts.append(text)
  90. return texts

以上所述使我们能够保持实际的绘制创作非常紧凑。

  1. fig, ax = plt.subplots()
  2. im, cbar = heatmap(harvest, vegetables, farmers, ax=ax,
  3. cmap="YlGn", cbarlabel="harvest [t/year]")
  4. texts = annotate_heatmap(im, valfmt="{x:.1f} t")
  5. fig.tight_layout()
  6. plt.show()

热度图示例2

一些更复杂的热度图示例

在下面的文章中,我们将通过在不同的情况下使用不同的参数来展示前面创建的函数的多样性。

  1. np.random.seed(19680801)
  2. fig, ((ax, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6))
  3. # Replicate the above example with a different font size and colormap.
  4. im, _ = heatmap(harvest, vegetables, farmers, ax=ax,
  5. cmap="Wistia", cbarlabel="harvest [t/year]")
  6. annotate_heatmap(im, valfmt="{x:.1f}", size=7)
  7. # Create some new data, give further arguments to imshow (vmin),
  8. # use an integer format on the annotations and provide some colors.
  9. data = np.random.randint(2, 100, size=(7, 7))
  10. y = ["Book {}".format(i) for i in range(1, 8)]
  11. x = ["Store {}".format(i) for i in list("ABCDEFG")]
  12. im, _ = heatmap(data, y, x, ax=ax2, vmin=0,
  13. cmap="magma_r", cbarlabel="weekly sold copies")
  14. annotate_heatmap(im, valfmt="{x:d}", size=7, threshold=20,
  15. textcolors=["red", "white"])
  16. # Sometimes even the data itself is categorical. Here we use a
  17. # :class:`matplotlib.colors.BoundaryNorm` to get the data into classes
  18. # and use this to colorize the plot, but also to obtain the class
  19. # labels from an array of classes.
  20. data = np.random.randn(6, 6)
  21. y = ["Prod. {}".format(i) for i in range(10, 70, 10)]
  22. x = ["Cycle {}".format(i) for i in range(1, 7)]
  23. qrates = np.array(list("ABCDEFG"))
  24. norm = matplotlib.colors.BoundaryNorm(np.linspace(-3.5, 3.5, 8), 7)
  25. fmt = matplotlib.ticker.FuncFormatter(lambda x, pos: qrates[::-1][norm(x)])
  26. im, _ = heatmap(data, y, x, ax=ax3,
  27. cmap=plt.get_cmap("PiYG", 7), norm=norm,
  28. cbar_kw=dict(ticks=np.arange(-3, 4), format=fmt),
  29. cbarlabel="Quality Rating")
  30. annotate_heatmap(im, valfmt=fmt, size=9, fontweight="bold", threshold=-1,
  31. textcolors=["red", "black"])
  32. # We can nicely plot a correlation matrix. Since this is bound by -1 and 1,
  33. # we use those as vmin and vmax. We may also remove leading zeros and hide
  34. # the diagonal elements (which are all 1) by using a
  35. # :class:`matplotlib.ticker.FuncFormatter`.
  36. corr_matrix = np.corrcoef(np.random.rand(6, 5))
  37. im, _ = heatmap(corr_matrix, vegetables, vegetables, ax=ax4,
  38. cmap="PuOr", vmin=-1, vmax=1,
  39. cbarlabel="correlation coeff.")
  40. def func(x, pos):
  41. return "{:.2f}".format(x).replace("0.", ".").replace("1.00", "")
  42. annotate_heatmap(im, valfmt=matplotlib.ticker.FuncFormatter(func), size=7)
  43. plt.tight_layout()
  44. plt.show()

热度图示例3

参考

下面的示例显示了以下函数和方法的用法:

  1. matplotlib.axes.Axes.imshow
  2. matplotlib.pyplot.imshow
  3. matplotlib.figure.Figure.colorbar
  4. matplotlib.pyplot.colorbar

下载这个示例