ID3决策树是比较经典的决策树,在周志华的机器学习中,生成决策树的算法为:
image.png
算法的关键是如何选择最优划分属性,在ID3决策树中,用信息增益来指导决策树选择最优划分属性
首先定义信息熵为:
image.png
再定义信息增益为:
image.png
一般而言,信息增益越大,意味着使用属性a进行划分所获得的纯度提升越大,因此我们选择最大信息增益的属性作为最优划分属性。
Python实现思路

树的数据表示

既然要实现一棵树,首先要做的就是定义节点的数据结构,在C中,节点一般以结构体的形式存储,所以我们在Python中可以参考这一思路定义一个节点类:

  1. class Node():
  2. """
  3. ID3决策树的节点
  4. parent -- 父节点
  5. sons -- 子节点集合,即在该节点最优划分属性下每个属性值的分支
  6. attrs -- 该节点下的最优划分属性
  7. parent_attrs_value -- 表示该节点是父节点哪一个属性的分支
  8. label -- 如果这个节点是叶子节点,则存放标签
  9. """
  10. def __init__(self, parent=None):
  11. self.parent = parent
  12. self.sons = []
  13. self.attr = None
  14. self.parent_attrs_value = None
  15. self.label = None

但在实际操作中,使用这一方法给代码的调试增加了难度,同时不利于后面用Graphviz包实现决策树的可视化,因此本文考虑使用另一种数据结构表示树,就是Python中的字典,我们先来看看对于西瓜书中给出的一颗决策树,用字典是如何表示的:
西瓜书中的一颗决策树:
image.png
对应的Python字典表示:

  1. tree = {<!-- -->'纹理':
  2. {<!-- -->'清晰':
  3. {<!-- -->'根蒂':
  4. {<!-- -->'蜷缩':
  5. {<!-- -->'label':'是'},
  6. '稍蜷':
  7. {<!-- -->'色泽':
  8. {<!-- -->'青绿':
  9. {<!-- -->'label':'是'},
  10. '乌黑':
  11. {<!-- -->'触感':
  12. {<!-- -->'硬滑':
  13. {<!-- -->'label':'是'},
  14. '软粘':
  15. {<!-- -->'label':'否'}}},
  16. '浅白':
  17. {<!-- -->'label':'是'}}},
  18. '硬挺':
  19. {<!-- -->'label':'否'}}},
  20. '稍糊':
  21. {<!-- -->'触感':
  22. {<!-- -->'硬滑':
  23. {<!-- -->'label':'否'},
  24. '软粘':
  25. {<!-- -->'label':'是'}}},
  26. '模糊':
  27. {<!-- -->'label':'否'}}}

如何可视化决策树

在本文中,使用Graphviz包进行决策树的可视化,这里是官网和文档
只需使用几条简单的代码便可将决策树的节点绘制出来:

  1. g = graphviz.Digraph(name=,filename=, format='png')
  2. g.node(name=, label=, fontname="Microsoft YaHei", shape=)
  3. g.edge(tail_name, head_name, label=, fontname="Microsoft YaHei")
  4. g.view()

要注意,如果决策树的信息是中文的,要在fontname参数中指定中文字体,不然会出现乱码
Python代码
DecesionTree.py

  1. import numpy as np
  2. import scipy.io as sio
  3. from collections import Counter
  4. from graphviz import Digraph
  5. class DecisionTree():
  6. """
  7. 一个构建ID3决策树的类
  8. attrs -- 存放属性的字典, 字典中,键为属性名,值为属性的取值,最后一个属性为标签属性
  9. X -- 训练数据
  10. y -- 标签
  11. attr_idx -- 属性列索引
  12. tree -- 生成的决策树,用字典形式存放
  13. node_name -- 用于对决策树的可视化,在graphviz中对节点的命名
  14. """
  15. def __init__(self):
  16. self.attrs = None
  17. self.X = None
  18. self.y = None
  19. self.attr_idx = None
  20. self.tree = {<!-- -->}
  21. self.node_name = "0"
  22. def get_attrs(self, data):
  23. """
  24. 对数据集进行处理,得到属性与对应的属性取值
  25. args:
  26. data -- 输入的数据矩阵, shape=(samples+1, features), dtype='<U?', 其中,第一行为属性,最后一列为标签
  27. returns:
  28. attrs -- 存放属性的字典, 字典中,键为属性名,值为属性的取值
  29. """
  30. attrs = {<!-- -->}
  31. for i in range(data.shape[1]):
  32. attrs_values = sorted(set(data[1:, i]))
  33. attrs[data[0][i]] = attrs_values
  34. self.attrs = attrs
  35. return attrs
  36. def generate_tree(self, data):
  37. """
  38. 生成决策树
  39. args:
  40. data -- 输入的数据矩阵, shape=(samples+1, features+label), dtype='<U?', 其中,第一行为属性,最后一列为标签
  41. """
  42. self.X = data[1:, :-1]
  43. self.y = data[1:, -1]
  44. # 先创建一个不含label属性的纯变量属性字典
  45. pure_attrs = self.attrs.copy()
  46. del(pure_attrs['label'])
  47. # 构造一个只含属性名的列表
  48. attr_names = [attr_name for attr_name in pure_attrs.keys()]
  49. # 将属性名编号,方便查找其在数据中对应的列
  50. attr_idx = {<!-- -->}
  51. for num, attr in enumerate(attr_names):
  52. attr_idx[attr] = num
  53. self.attr_idx = attr_idx
  54. # 生成根节点
  55. self.tree['root_node'] = {<!-- -->}
  56. self._generate_tree(self.X, self.y, self.tree['root_node'], pure_attrs, attr_idx)
  57. self.tree = self.tree['root_node']
  58. def _generate_tree(self, X, y, node, attrs, attr_idx):
  59. """
  60. 递归生成决策树
  61. args:
  62. X -- 输入的数据矩阵, shape=(samples, features), dtype='<U?'
  63. y -- 标签, shape=(samples, )
  64. parent_node -- 父节点,此次递归函数是父节点的某一个属性值的递归
  65. attrs -- 属性字典, 即从父节点分支到现在的节点时,还没有被划分的属性
  66. attr_idx -- 属性在数据中列索引
  67. """
  68. #--------- 如果训练集中样本全属于同一类别 ---------#
  69. if len(set(y.tolist())) == 1:
  70. node['label'] = y[0]
  71. return
  72. #-------- 如果属性集为空集或者训练集中样本在属性集上取值相同 ---------#
  73. # 判断训练集样本在属性集中取值是否相同
  74. same = True
  75. for i in range(X.shape[1]):
  76. if len(set(X[:, i].tolist())) > 1:
  77. same = False
  78. if not attrs or same:
  79. y_counter = Counter(y)
  80. most_y = y_counter.most_common()[0][0]
  81. node['label'] = most_y
  82. return
  83. #--------- 选择最优属性生成分支 ---------#
  84. # 选出最优划分属性
  85. optimal_attr = self.choose_optimal_attr(X, y, attrs, attr_idx)
  86. node[optimal_attr] = {<!-- -->}
  87. node = node[optimal_attr]
  88. # 对于最优划分属性下每个属性值
  89. for attr_value in attrs[optimal_attr]:
  90. # 生成分支
  91. node[attr_value] = {<!-- -->}
  92. # 令Dv表示X中在optimal_attr上取值为attr_value的样本子集
  93. Dv = X.copy()
  94. attr_value_idx = Dv[:, attr_idx[optimal_attr]] == attr_value
  95. Dv = Dv[attr_value_idx, :]
  96. y_Dv = y[attr_value_idx]
  97. Dv = np.delete(Dv, attr_idx[optimal_attr], 1)
  98. # 如果Dv为空
  99. if Dv.size == 0:
  100. # 将分支节点标记为叶节点,其类别标记为X中样本最多的类,即统计y
  101. y_counter = Counter(y)
  102. most_y = y_counter.most_common()[0][0]
  103. node[attr_value]['label'] = most_y
  104. else:
  105. # 更新属性字典
  106. new_attrs = attrs.copy()
  107. del(new_attrs[optimal_attr])
  108. # 更新属性列索引
  109. new_attr_names = [new_attr_name for new_attr_name in new_attrs.keys()]
  110. new_attr_idx = {<!-- -->}
  111. for num, attr in enumerate(new_attr_names):
  112. new_attr_idx[attr] = num
  113. self._generate_tree(Dv, y_Dv, node[attr_value], new_attrs, new_attr_idx)
  114. def compute_Ent(self, y):
  115. """
  116. 计算给出属性名列表所对应的所有样本的信息熵
  117. args:
  118. y -- 标签数组, shape=(samples, )
  119. return:
  120. Ent -- 样本的信息熵
  121. """
  122. Ent = 0
  123. m = np.size(y)
  124. for label in self.attrs['label']:
  125. pk = np.sum(y == label)
  126. pk = pk / m
  127. log2pk = np.log2(pk + 1e-8) # 防止算得0,导致返回nan
  128. Ent -= pk * log2pk
  129. return Ent
  130. def choose_optimal_attr(self, X, y, attrs, attr_idx):
  131. """
  132. 选择最优划分属性 划分标准:属性的信息增益
  133. args:
  134. X -- 输入的数据矩阵, shape=(samples, features), dtype='<U?'
  135. y -- 标签, shape=(samples, )
  136. attrs -- 属性字典
  137. attr_idx -- 属性在数据中列索引
  138. returns:
  139. max_gain_attr -- 最大的信息增益对应的属性
  140. """
  141. # 计算当前所含属性对应所有样本的信息熵
  142. Ent = self.compute_Ent(y)
  143. m = np.size(y)
  144. # 记录当前最大的信息增益以及对应的属性
  145. max_gain = 0
  146. max_gain_attr = None
  147. # 计算每一个属性的信息增益
  148. for attr, idx in attr_idx.items():
  149. x = X[:, idx]
  150. gain = Ent
  151. # 计算一个属性中每个属性值的信息熵
  152. for attr_value in attrs[attr]:
  153. _y = y[x==attr_value]
  154. if _y.size != 0:
  155. ent = self.compute_Ent(_y)
  156. else:
  157. ent = 0
  158. gain -= np.size(_y) / m * ent
  159. if gain > max_gain:
  160. max_gain = gain
  161. max_gain_attr = attr
  162. return max_gain_attr
  163. def predict(self, predict_x):
  164. """
  165. 预测样本结果
  166. args:
  167. predict_x -- 预测样本数据矩阵 shape=(samples, features)
  168. returns:
  169. predict_y -- 样本的预测结果 shape=(samples, )
  170. """
  171. s = predict_x.shape[0]
  172. predict_y = []
  173. for i in range(s):
  174. node = self.tree
  175. while(1):
  176. if 'label' in node.keys():
  177. predict_y.append(node['label'])
  178. break
  179. elif list(node.keys())[0] in self.attrs.keys():
  180. attr = list(node.keys())[0]
  181. idx = self.attr_idx[attr]
  182. node = node[attr]
  183. else:
  184. node = node[predict_x[i, idx]]
  185. return predict_y
  186. def tree_traversal(self, g, parent_node, parent_node_name, parent_attr, parent_attr_value):
  187. """
  188. 对树进行遍历,生成可视化的节点
  189. g -- 要绘制的有向图
  190. parent_node -- 父节点
  191. parent_node_name -- 父节点在有向图中的代号
  192. parent_attr -- 父节点的属性
  193. parent_attr_value -- 父节点到该节点的属性值
  194. """
  195. if (parent_attr and parent_attr_value) is None:
  196. if 'label' in parent_node.keys():
  197. g.node(name=self.node_name, label=parent_node['label'], fontname="Microsoft YaHei")
  198. return
  199. else:
  200. attr = list(parent_node.keys())[0]
  201. node = parent_node[attr]
  202. parent_node_name = "0"
  203. for attr_value in node.keys():
  204. self.tree_traversal(g, node[attr_value], parent_node_name, attr, attr_value)
  205. else:
  206. if 'label' in parent_node.keys():
  207. g.node(name=parent_node_name, label=parent_attr, fontname="Microsoft YaHei", shape='box')
  208. self.node_name = str(int(self.node_name) + 1)
  209. g.node(name=self.node_name, label=parent_node['label'], fontname="Microsoft YaHei")
  210. g.edge(parent_node_name, self.node_name, label=parent_attr_value, fontname="Microsoft YaHei")
  211. else:
  212. attr = list(parent_node.keys())[0]
  213. g.node(name=parent_node_name, label=parent_attr, fontname="Microsoft YaHei", shape='box')
  214. self.node_name = str(int(self.node_name) + 1)
  215. g.node(name=self.node_name, label=attr, fontname="Microsoft YaHei", shape='box')
  216. g.edge(parent_node_name, self.node_name, label=parent_attr_value, fontname="Microsoft YaHei")
  217. node = parent_node[attr]
  218. parent_node_name = self.node_name
  219. for attr_value in node.keys():
  220. self.tree_traversal(g, node[attr_value], parent_node_name, attr, attr_value)
  221. def tree_visualize(self, file_name=None):
  222. """
  223. 将决策树可视化
  224. args:
  225. file_name -- 若给出该参数,则将决策树保存为file_name的图片
  226. """
  227. if file_name:
  228. g = Digraph("Decision Tree", filename=file_name, format='png')
  229. else:
  230. g = Digraph("Decision Tree")
  231. self.tree_traversal(g, self.tree, None, None, None)
  232. g.view()
  233. if __name__ == "__main__":
  234. pass

主函数,以西瓜树的西瓜数据集为例生成决策树,原数据集是Matlab的cell数组,并以mat文件存放,因此需要预处理一下:

  1. import numpy as np
  2. import scipy.io as sio
  3. from DecisionTree import DecisionTree
  4. def preprocess():
  5. raw_data = sio.loadmat('watermelon.mat')
  6. raw_data = raw_data['watermelon']
  7. data = np.zeros(raw_data.shape, dtype='<U20')
  8. for i in range(data.shape[0]):
  9. for j in range(data.shape[1]):
  10. data[i, j] = raw_data[i, j][0]
  11. data[0, -1] = 'label'
  12. return data
  13. def main_1():
  14. """
  15. 完整决策树
  16. """
  17. data = preprocess()
  18. DTree = DecisionTree()
  19. attrs = DTree.get_attrs(data)
  20. DTree.generate_tree(data)
  21. DTree.tree_visualize('watermelob_tree')
  22. def main_2():
  23. """
  24. 留出两个样本作为测试集
  25. """
  26. data = preprocess()
  27. train_idx = np.delete(np.arange(0, 18), [8, 17])
  28. test_idx = [8, 17]
  29. train_data = data[train_idx, :]
  30. test_data = data[test_idx, :]
  31. test_X = test_data[:, :-1]
  32. test_y = test_data[:, -1]
  33. DTree = DecisionTree()
  34. DTree.get_attrs(train_data)
  35. DTree.generate_tree(train_data)
  36. predict_y = DTree.predict(test_X)
  37. print(predict_y)
  38. DTree.tree_visualize('watermelon_tree_2')
  39. main_1()

最终生成的决策树图片为:
image.png
到这里我们就成功地用Python实现了ID3决策树!

https://www.codenong.com/cs109634151/