1.需求说明

求出决策树的字典存储形式数据后,绘制出决策树的图形,则会更形象认识和了解其决策树。
比如,有决策树的字典结构如下所示:

  1. tree_dict = {'house?': {'hourse_no': {'working?': {'work_no': 'refuse', 'work_yes': 'agree'}}, 'hourse_yes': 'agree'}}

要绘制出对应的如下决策树:
【决策树】有了决策树的字典结构后 ,如何用python绘制决策树? - 图1
本章代码就是完成此需求的。

2. 代码

  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: 蔚蓝的天空Tom
  4. Aim:得到决策树的字典后,需要使用python来绘制对应的决策树figure
  5. 输入决策树的字典,样例如下所示:
  6. dtree = {'house?': {'hourse_no': {'working?': {'work_no': 'refuse', 'work_yes': 'agree'}}, 'hourse_yes': 'agree'}}
  7. """
  8. import matplotlib.pyplot as plt
  9. #定义判断结点形状,其中boxstyle表示文本框类型,fc指的是注释框颜色的深度
  10. decisionNode = dict(boxstyle="round4", color='r', fc='0.9')
  11. #定义叶结点形状
  12. leafNode = dict(boxstyle="circle", color='m')
  13. #定义父节点指向子节点或叶子的箭头形状
  14. arrow_args = dict(arrowstyle="<-", color='g')
  15. def plot_node(node_txt, center_point, parent_point, node_style):
  16. '''
  17. 绘制父子节点,节点间的箭头,并填充箭头中间上的文本
  18. :param node_txt:文本内容
  19. :param center_point:文本中心点
  20. :param parent_point:指向文本中心的点
  21. '''
  22. createPlot.ax1.annotate(node_txt,
  23. xy=parent_point,
  24. xycoords='axes fraction',
  25. xytext=center_point,
  26. textcoords='axes fraction',
  27. va="center",
  28. ha="center",
  29. bbox=node_style,
  30. arrowprops=arrow_args)
  31. def get_leafs_num(tree_dict):
  32. '''
  33. 获取叶节点的个数
  34. :param tree_dict:树的数据字典
  35. :return tree_dict的叶节点总个数
  36. '''
  37. #tree_dict的叶节点总数
  38. leafs_num = 0
  39. #字典的第一个键,也就是树的第一个节点
  40. root = list(tree_dict.keys())[0]
  41. #这个键所对应的值,即该节点的所有子树。
  42. child_tree_dict =tree_dict[root]
  43. for key in child_tree_dict.keys():
  44. #检测子树是否字典型
  45. if type(child_tree_dict[key]).__name__=='dict':
  46. #子树是字典型,则当前树的叶节点数加上此子树的叶节点数
  47. leafs_num += get_leafs_num(child_tree_dict[key])
  48. else:
  49. #子树不是字典型,则当前树的叶节点数加1
  50. leafs_num += 1
  51. #返回tree_dict的叶节点总数
  52. return leafs_num
  53. def get_tree_max_depth(tree_dict):
  54. '''
  55. 求树的最深层数
  56. :param tree_dict:树的字典存储
  57. :return tree_dict的最深层数
  58. '''
  59. #tree_dict的最深层数
  60. max_depth = 0
  61. #树的根节点
  62. root = list(tree_dict.keys())[0]
  63. #当前树的所有子树的字典
  64. child_tree_dict = tree_dict[root]
  65. for key in child_tree_dict.keys():
  66. #树的当前分支的层数
  67. this_path_depth = 0
  68. #检测子树是否字典型
  69. if type(child_tree_dict[key]).__name__ == 'dict':
  70. #如果子树是字典型,则当前分支的层数需要加上子树的最深层数
  71. this_path_depth = 1 + get_tree_max_depth(child_tree_dict[key])
  72. else:
  73. #如果子树不是字典型,则是叶节点,则当前分支的层数为1
  74. this_path_depth = 1
  75. if this_path_depth > max_depth:
  76. max_depth = this_path_depth
  77. #返回tree_dict的最深层数
  78. return max_depth
  79. def plot_mid_text(center_point, parent_point, txt_str):
  80. '''
  81. 计算父节点和子节点的中间位置,并在父子节点间填充文本信息
  82. :param center_point:文本中心点
  83. :param parent_point:指向文本中心点的点
  84. '''
  85. x_mid = (parent_point[0] - center_point[0])/2.0 + center_point[0]
  86. y_mid = (parent_point[1] - center_point[1])/2.0 + center_point[1]
  87. createPlot.ax1.text(x_mid, y_mid, txt_str)
  88. return
  89. def plotTree(tree_dict, parent_point, node_txt):
  90. '''
  91. 绘制树
  92. :param tree_dict:树
  93. :param parent_point:父节点位置
  94. :param node_txt:节点内容
  95. '''
  96. leafs_num = get_leafs_num(tree_dict)
  97. root = list(tree_dict.keys())[0]
  98. #plotTree.totalW表示树的深度
  99. center_point = (plotTree.xOff+(1.0+float(leafs_num))/2.0/plotTree.totalW,plotTree.yOff)
  100. #填充node_txt内容
  101. plot_mid_text(center_point, parent_point, node_txt)
  102. #绘制箭头上的内容
  103. plot_node(root, center_point, parent_point, decisionNode)
  104. #子树
  105. child_tree_dict = tree_dict[root]
  106. plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
  107. #因从上往下画,所以需要依次递减y的坐标值,plotTree.totalD表示存储树的深度
  108. for key in child_tree_dict.keys():
  109. if type(child_tree_dict[key]).__name__ == 'dict':
  110. plotTree(child_tree_dict[key],center_point,str(key))
  111. else:
  112. plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
  113. plot_node(child_tree_dict[key],(plotTree.xOff,plotTree.yOff),center_point,leafNode)
  114. plot_mid_text((plotTree.xOff,plotTree.yOff),center_point,str(key))
  115. #h绘制完所有子节点后,增加全局变量Y的偏移
  116. plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
  117. return
  118. def createPlot(tree_dict):
  119. '''
  120. 绘制决策树图形
  121. :param tree_dict
  122. :return 无
  123. '''
  124. #设置绘图区域的背景色
  125. fig=plt.figure(1,facecolor='white')
  126. #清空绘图区域
  127. fig.clf()
  128. #定义横纵坐标轴,注意不要设置xticks和yticks的值!!!
  129. axprops = dict(xticks=[], yticks=[])
  130. createPlot.ax1=plt.subplot(111, frameon=False, **axprops)
  131. #由全局变量createPlot.ax1定义一个绘图区,111表示一行一列的第一个,frameon表示边框,**axprops不显示刻度
  132. plotTree.totalW=float(get_leafs_num(tree_dict))
  133. plotTree.totalD=float(get_tree_max_depth(tree_dict))
  134. plotTree.xOff=-0.5/plotTree.totalW;
  135. plotTree.yOff=1.0;
  136. plotTree(tree_dict, (0.5,1.0), '')
  137. plt.show()
  138. if __name__=='__main__':
  139. tree_dict = {'house?': {'hourse_no': {'working?': {'work_no': 'refuse', 'work_yes': 'agree'}}, 'hourse_yes': 'agree'}}
  140. createPlot(tree_dict)

3.运行结果

3.1 如果树的字典存储为

  1. tree_dict = {'house?': {'hourse_no': {'working?': {'work_no': 'refuse', 'work_yes': 'agree'}}, 'hourse_yes': 'agree'}}

则绘制的树图形为:

【决策树】有了决策树的字典结构后 ,如何用python绘制决策树? - 图2
3.2 如果树的字典存储为

  1. tree_dict = {'no surfacing': {0:{'flippers': {0: 'no', 1: 'yes'}}, 1: {'flippers': {0: 'no', 1: 'yes'}}, 2:{'flippers': {0: 'no', 1: 'yes'}}}}

则绘制的树图形为:
【决策树】有了决策树的字典结构后 ,如何用python绘制决策树? - 图3