来源:https://mp.weixin.qq.com/s?__biz=MzA4OTAwMjY2Nw==&mid=2650187831&idx=1&sn=6d2c49c181aa173df08d6e70ad5f5ca2&chksm=88238ff3bf5406e5d70b96c854ae77ae1e00a3099bb1269bd315b8d4b1114ec0a03d746643c8&scene=21#wechat_redirect

风控策略同学在挖掘有效的风控规则的时候,经常需要基于业务经验,将那几个特征进行组合形成风控策略,会导致在特征组合的时候浪费大量的时间,我们有没有什么方法,替代人工的分析,直接得出策略组合呢,决策树就是其中的一个选择,可以实现自动化的挖掘大批量的策略组合。
在众多的算法中,决策树整体分类准确率不高,但是部分叶子节点的准确率却可以很高,因此我们可以提取决策树的叶子规则,并筛选准确率比较高的叶子节点,作为风控策略挖掘手段,并进行策略推荐,替代人工或者辅助人工,大大提高策略发现的效率于效果。

数据下载:https://pan.baidu.com/s/14YTXEhoUf_4HYKtZKIhayg?dp-logid=70888500703050440002#/home/%2F/%2F

*决策树三种可视化展示形态 - 图1
策略节选
*决策树三种可视化展示形态 - 图2

代码:

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Wed Mar 29 08:55:46 2023
  4. @author: yingtao.xiang
  5. """
  6. #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
  7. # 一、数据读取 #
  8. import pandas as pd
  9. import numpy as np
  10. pd.set_option('display.max_columns', None)#显示所有的列
  11. path = '/Users\yingtao.xiang\Downloads/train.csv'
  12. train = pd.read_csv(path).fillna(-1)
  13. train.columns
  14. # Index(['id', 'XINGBIE', 'CSNY', 'HYZK', 'ZHIYE', 'ZHICHEN', 'ZHIWU', 'XUELI',
  15. # 'DWJJLX', 'DWSSHY', 'GRJCJS', 'GRZHZT', 'GRZHYE', 'GRZHSNJZYE',
  16. # 'GRZHDNGJYE', 'GRYJCE', 'DWYJCE', 'DKFFE', 'DKYE', 'DKLL', 'label'],
  17. # dtype='object')
  18. train.head()#查看前面的数据
  19. # id XINGBIE CSNY HYZK ZHIYE ZHICHEN ZHIWU XUELI DWJJLX \
  20. # 0 train_0 1 1038672000 90 90 999 0 99 150
  21. # 1 train_1 2 504892800 90 90 999 0 99 110
  22. # 2 train_2 1 736185600 90 90 999 0 99 150
  23. # 3 train_3 1 428515200 90 90 999 0 99 150
  24. # 4 train_4 2 544204800 90 90 999 0 99 900
  25. # DWSSHY GRJCJS GRZHZT GRZHYE GRZHSNJZYE GRZHDNGJYE GRYJCE \
  26. # 0 12 1737.0 1 3223.515 801.310 837.000 312.00
  27. # 1 0 4894.0 1 18055.195 53213.220 1065.200 795.84
  28. # 2 9 10297.0 1 27426.600 13963.140 7230.020 1444.20
  29. # 3 7 10071.5 1 111871.130 99701.265 2271.295 1417.14
  30. # 4 14 2007.0 1 237.000 11028.875 35.780 325.50
  31. # DWYJCE DKFFE DKYE DKLL label
  32. # 0 312.00 175237 154112.935 2.708 0
  33. # 1 795.84 300237 298252.945 2.979 0
  34. # 2 1444.20 150237 147339.130 2.708 0
  35. # 3 1417.14 350237 300653.780 2.708 0
  36. # 4 325.50 150237 145185.010 2.708 0
  37. #构建训练集
  38. X = train.loc[:,'XINGBIE':'DKLL']
  39. Y = train['label']
  40. from sklearn import tree
  41. clf = tree.DecisionTreeClassifier(
  42. max_depth=3,
  43. min_samples_leaf=50
  44. )
  45. clf = clf.fit(X, Y)
  46. #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
  47. #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
  48. # 二、决策树的可视化 #
  49. # 1、plot_tree(太丑,不推荐)
  50. #包里自带的,有点丑
  51. tree.plot_tree(clf)
  52. plt.show()
  53. # 2、graphviz对决策树进行可视化
  54. import graphviz
  55. dot_data = tree.export_graphviz(
  56. clf,
  57. out_file=None,
  58. feature_names=X.columns,
  59. class_names=['good','bad'],
  60. filled=True, rounded=True,
  61. special_characters=True)
  62. graph = graphviz.Source(dot_data)
  63. graph
  64. # 3、dtreeviz对决策树进行可视化
  65. # cannot import name 'dtreeviz' from 'dtreeviz.trees'
  66. # 这里代码 不对 现在不用这个dtreeviz.trees
  67. # from dtreeviz.trees import dtreeviz
  68. # testX = X.iloc[77,:]
  69. # viz = dtreeviz(clf,X,Y,
  70. # feature_names=np.array(X.columns),
  71. # # class_names=['good','bad'],
  72. # class_names={0:'good',1:'bad'},
  73. # X = testX)
  74. # viz.view()
  75. # 最新代码 是用 dtreeviz.model
  76. import dtreeviz
  77. # viz_model = dtreeviz.model(clf,
  78. # X_train=X, y_train=y,
  79. # feature_names=iris.feature_names,
  80. # target_name='iris',
  81. # class_names=iris.target_names)
  82. # v = viz_model.view() # render as SVG into internal object
  83. # v.show() # pop up window
  84. # v.save("/tmp/iris.svg") # optionally save as svg
  85. testX = X.iloc[77,:]
  86. # feature_names 要list
  87. viz = dtreeviz.model(clf,
  88. X_train=X, y_train=Y,
  89. feature_names=X.columns,
  90. class_names=['good','bad']
  91. # feature_names=np.array(X.columns),
  92. # class_names=['good','bad'],
  93. # class_names={0:'good',1:'bad'}
  94. )
  95. v=viz.view()
  96. v.show()
  97. # def model(model,
  98. # X_train,
  99. # y_train,
  100. # tree_index: int = None,
  101. # feature_names: List[str] = None,
  102. # target_name: str = None,
  103. # class_names: (List[str], Mapping[int, str]) = None)
  104. # 我们把树的深度再加深到5看看,树更复杂了
  105. from sklearn import tree
  106. clf = tree.DecisionTreeClassifier(
  107. max_depth=5,
  108. min_samples_leaf=50
  109. )
  110. clf = clf.fit(X, Y)
  111. testX = X.iloc[77,:]
  112. viz = dtreeviz.model(clf,
  113. X_train=X, y_train=Y,
  114. feature_names=X.columns,
  115. # feature_names=np.array(X.columns),
  116. # class_names=['good','bad'],
  117. class_names={0:'good',1:'bad'}
  118. )
  119. v=viz.view()
  120. v.show()
  121. # from dtreeviz import trees
  122. # from dtreeviz.models.sklearn_decision_trees import ShadowSKDTree
  123. # from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
  124. # random_state=1234
  125. # dataset=X
  126. # tree_classifier = DecisionTreeClassifier(max_depth=4, random_state=random_state)
  127. # tree_classifier.fit(X, Y)
  128. # sk_dtree = ShadowSKDTree(tree_classifier, X, Y, X.columns, 'label', [0, 1])
  129. # trees.dtreeviz(tree_classifier, X, Y, X.columns, 'label', class_names=[0, 1])
  130. # trees.dtreeviz(sk_dtree, fancy=False)
  131. #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
  132. #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
  133. # 四、决策规则提取 #
  134. # 1、决策树的生成的结构探索
  135. from sklearn.datasets import load_iris
  136. from sklearn import tree
  137. iris = load_iris()
  138. clf = tree.DecisionTreeClassifier()
  139. clf = clf.fit(iris.data, iris.target)
  140. clf.classes_
  141. [x for x in dir(clf) if not x.startswith('_')]
  142. dir(clf.tree_)
  143. # ['apply','capacity', 'children_left','children_right',
  144. # 'compute_feature_importances','compute_partial_dependence',
  145. # 'decision_path','feature',
  146. # 'impurity','max_depth',
  147. # 'max_n_classes','n_classes',
  148. # 'n_features','n_leaves',
  149. # 'n_node_samples','n_outputs','node_count',
  150. # 'predict','threshold',
  151. # 'value', 'weighted_n_node_samples']
  152. # 2、老方法提取决策树规则
  153. import pandas as pd
  154. import numpy as np
  155. pd.set_option('display.max_columns', None)#显示所有的列
  156. path = '/Users\yingtao.xiang\Downloads/train.csv'
  157. train = pd.read_csv(path).fillna(-1)
  158. train.columns
  159. X = train.loc[:,'XINGBIE':'DKLL']
  160. Y = train['label']
  161. from sklearn import tree
  162. clf = tree.DecisionTreeClassifier(
  163. max_depth=3,
  164. min_samples_leaf=50
  165. )
  166. clf = clf.fit(X, Y)
  167. from sklearn.tree import _tree
  168. def tree_to_code(tree, feature_names):
  169. tree_ = tree.tree_
  170. feature_name = [
  171. feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
  172. for i in tree_.feature
  173. ]
  174. print ("def tree({}):".format(", ".join(feature_names)))
  175. def recurse(node, depth):
  176. indent = " " * depth
  177. if tree_.feature[node] != _tree.TREE_UNDEFINED:
  178. name = feature_name[node]
  179. threshold = tree_.threshold[node]
  180. print("{}if {} <= {}:".format(indent, name, threshold))
  181. recurse(tree_.children_left[node], depth + 1)
  182. print("{}else: # if {} > {}".format(indent, name, threshold))
  183. recurse(tree_.children_right[node], depth + 1)
  184. else:
  185. print("{}return {}".format(indent, tree_.value[node]))
  186. recurse(0, 1)
  187. tree_to_code(clf,X.columns)
  188. # def tree(XINGBIE, CSNY, HYZK, ZHIYE, ZHICHEN, ZHIWU, XUELI, DWJJLX, DWSSHY, GRJCJS, GRZHZT, GRZHYE, GRZHSNJZYE, GRZHDNGJYE, GRYJCE, DWYJCE, DKFFE, DKYE, DKLL):
  189. # if GRZHZT <= 1.5:
  190. # if DWSSHY <= 14.5:
  191. # if DWJJLX <= 177.0:
  192. # return [[24563. 524.]]
  193. # else: # if DWJJLX > 177.0
  194. # return [[4969. 443.]]
  195. # else: # if DWSSHY > 14.5
  196. # if DWYJCE <= 747.4400024414062:
  197. # return [[5309. 482.]]
  198. # else: # if DWYJCE > 747.4400024414062
  199. # return [[2397. 1135.]]
  200. # else: # if GRZHZT > 1.5
  201. # if GRZHYE <= 3309.320068359375:
  202. # return [[ 0. 125.]]
  203. # else: # if GRZHYE > 3309.320068359375
  204. # return [[ 5. 48.]]
  205. # 新方法提取决策树规则
  206. def Get_Rules(clf,X):
  207. n_nodes = clf.tree_.node_count
  208. children_left = clf.tree_.children_left
  209. children_right = clf.tree_.children_right
  210. feature = clf.tree_.feature
  211. threshold = clf.tree_.threshold
  212. value = clf.tree_.value
  213. node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
  214. is_leaves = np.zeros(shape=n_nodes, dtype=bool)
  215. stack = [(0, 0)]
  216. while len(stack) > 0:
  217. node_id, depth = stack.pop()
  218. node_depth[node_id] = depth
  219. is_split_node = children_left[node_id] != children_right[node_id]
  220. if is_split_node:
  221. stack.append((children_left[node_id], depth+1))
  222. stack.append((children_right[node_id], depth+1))
  223. else:
  224. is_leaves[node_id] = True
  225. feature_name = [
  226. X.columns[i] if i != _tree.TREE_UNDEFINED else "undefined!"
  227. for i in clf.tree_.feature]
  228. ways = []
  229. depth = []
  230. feat = []
  231. nodes = []
  232. rules = []
  233. for i in range(n_nodes):
  234. if is_leaves[i]:
  235. while depth[-1] >= node_depth[i]:
  236. depth.pop()
  237. ways.pop()
  238. feat.pop()
  239. nodes.pop()
  240. if children_left[i-1]==i:#当前节点是上一个节点的左节点,则是小于
  241. a='{f}<={th}'.format(f=feat[-1],th=round(threshold[nodes[-1]],4))
  242. ways[-1]=a
  243. last =' & '.join(ways)+':'+str(value[i][0][0])+':'+str(value[i][0][1])
  244. rules.append(last)
  245. else:
  246. a='{f}>{th}'.format(f=feat[-1],th=round(threshold[nodes[-1]],4))
  247. ways[-1]=a
  248. last = ' & '.join(ways)+':'+str(value[i][0][0])+':'+str(value[i][0][1])
  249. rules.append(last)
  250. else: #不是叶子节点 入栈
  251. if i==0:
  252. ways.append(round(threshold[i],4))
  253. depth.append(node_depth[i])
  254. feat.append(feature_name[i])
  255. nodes.append(i)
  256. else:
  257. while depth[-1] >= node_depth[i]:
  258. depth.pop()
  259. ways.pop()
  260. feat.pop()
  261. nodes.pop()
  262. if i==children_left[nodes[-1]]:
  263. w='{f}<={th}'.format(f=feat[-1],th=round(threshold[nodes[-1]],4))
  264. else:
  265. w='{f}>{th}'.format(f=feat[-1],th=round(threshold[nodes[-1]],4))
  266. ways[-1] = w
  267. ways.append(round(threshold[i],4))
  268. depth.append(node_depth[i])
  269. feat.append(feature_name[i])
  270. nodes.append(i)
  271. return rules
  272. # 利用函数对规则进行提取
  273. #训练一个决策树,对规则进行提取
  274. clf = tree.DecisionTreeClassifier(max_depth=10,
  275. min_samples_leaf=50)
  276. clf = clf.fit(X, Y)
  277. Rules = Get_Rules(clf,X)
  278. Rules[0:5] # 查看前5条规则
  279. # ['GRZHZT<=1.5 & DWSSHY<=14.5 & DWJJLX<=177.0 & DWJJLX<=115.0 & DKYE<=111236.2852 & DWSSHY<=4.5 & DWYJCE<=663.54 & DKYE<=67419.1094:45.0:8.0',
  280. # 'GRZHZT<=1.5 & DWSSHY<=14.5 & DWJJLX<=177.0 & DWJJLX<=115.0 & DKYE<=111236.2852 & DWSSHY<=4.5 & DWYJCE<=663.54 & DKYE >67419.1094:61.0:3.0',
  281. # 'GRZHZT<=1.5 & DWSSHY<=14.5 & DWJJLX<=177.0 & DWJJLX<=115.0 & DKYE<=111236.2852 & DWSSHY<=4.5 & DWYJCE >663.54 & GRZHYE<=45622.4883 & DKYE<=1825.5625:63.0:2.0',
  282. # 'GRZHZT<=1.5 & DWSSHY<=14.5 & DWJJLX<=177.0 & DWJJLX<=115.0 & DKYE<=111236.2852 & DWSSHY<=4.5 & DWYJCE >663.54 & GRZHYE<=45622.4883 & DKYE >1825.5625:188.0:0.0',
  283. # 'GRZHZT<=1.5 & DWSSHY<=14.5 & DWJJLX<=177.0 & DWJJLX<=115.0 & DKYE<=111236.2852 & DWSSHY<=4.5 & DWYJCE >663.54 & GRZHYE >45622.4883:46.0:4.0']
  284. len(Rules) # 查看规则总数
  285. # 182
  286. # 提高树的深度再看看,max_depth=15,可以看到规则数从182变成了521条,规模更大
  287. clf = tree.DecisionTreeClassifier(max_depth=15,min_samples_leaf=20)
  288. clf = clf.fit(X, Y)
  289. Rules = Get_Rules(clf,X)
  290. Rules[0:5]
  291. # ['GRZHZT<=1.5 & DWSSHY<=14.5 & DWJJLX<=177.0 & DWJJLX<=115.0 & DKYE<=111236.2852 & DWSSHY<=4.5 & DWYJCE<=663.54 & GRZHSNJZYE<=19428.9082 & DKFFE<=142737.0 & CSNY<=600926400.0:54.0:0.0',
  292. # 'GRZHZT<=1.5 & DWSSHY<=14.5 & DWJJLX<=177.0 & DWJJLX<=115.0 & DKYE<=111236.2852 & DWSSHY<=4.5 & DWYJCE<=663.54 & GRZHSNJZYE<=19428.9082 & DKFFE<=142737.0 & CSNY >600926400.0:18.0:2.0',
  293. # 'GRZHZT<=1.5 & DWSSHY<=14.5 & DWJJLX<=177.0 & DWJJLX<=115.0 & DKYE<=111236.2852 & DWSSHY<=4.5 & DWYJCE<=663.54 & GRZHSNJZYE<=19428.9082 & DKFFE >142737.0:19.0:4.0',
  294. # 'GRZHZT<=1.5 & DWSSHY<=14.5 & DWJJLX<=177.0 & DWJJLX<=115.0 & DKYE<=111236.2852 & DWSSHY<=4.5 & DWYJCE<=663.54 & GRZHSNJZYE >19428.9082:15.0:5.0',
  295. # 'GRZHZT<=1.5 & DWSSHY<=14.5 & DWJJLX<=177.0 & DWJJLX<=115.0 & DKYE<=111236.2852 & DWSSHY<=4.5 & DWYJCE >663.54 & GRZHYE<=73608.0156 & DKYE<=1825.5625 & GRZHSNJZYE<=9524.7949:21.0:2.0']
  296. len(Rules)
  297. # 521
  298. #可以遍历所有的规则
  299. for i in Rules:
  300. print(i)
  301. pd.DataFrame(Rules).to_excel('rules.xlsx',index=False)
  302. #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

一、数据说明及读取

1、数据集信息

数据从真实场景和实际应用出发,利用个人的基本身份信息、个人的住房公积金缴存和贷款等数据信息,来建立准确的风险控制模型,来预测用户是否会逾期还款。一共提供了40000带标签训练集样本,数据仅有一张表,一共有19个基本特征,且均不包含任何缺失值。

2、数据属性信息

标签:label是否逾期(是 = 1,否 = 0)。
特征:包含以下19个变量,名称和含义如下。

序号 字段名 类型 说明
1 id String 主键
2 XINGBIE int 性别
3 CSNY int 出生年月
4 HYZK int 婚姻状况
5 ZHIYE int 职业
6 ZHICHEN int 职称
7 ZHIWU int 职务
8 XUELI int 学历
9 DWJJLX int 单位经济类型
10 DWSSHY int 单位所属行业
11 GRJCJS float 个人缴存基数
12 GRZHZT int 个人账户状态
13 GRZHYE float 个人账户余额
14 GRZHSNJZYE float 个人账户上年结转余额
15 GRZHDNGJYE float 个人账户当年归集余额
16 GRYICE float 个人月缴存额
17 DWYICE float 单位月缴存额
18 DKFFE int 贷款发放额
19 DKYE float 贷款余额
20 DKLL float 贷款利率
21 label int 是否逾期 (0代表没逾期1代表逾期)

3、读取数据

二、构建决策树

三、决策树的可视化

决策树可视化的方案比较多,都写出来给对比看看,推荐第二种和第三种。对决策树进行可视化,是非常有必要的,能够帮助我们自己充分理解决策树的生成过程,如果是风控,也有利于咱们给业务部门解释数据和结果。

1、plot_tree(太丑,不推荐)

*决策树三种可视化展示形态 - 图3

2、graphviz对决策树进行可视化

如果搜索“可视化决策树”,很快便能找到由scikit提供的基于Python语言的解决方案:sklearn.tree.export_graphviz,这个也是最常用的解决方案
image.png

生成的决策树解释
1)samples:节点中观察的数量,比如根节点40000,表示数据集总共有4万个样本
2)有多少种类别,整棵树的叶子就有多少种颜色,比如我们这里有2个类别,颜色对应是黄、绿、Gini指数越小,该节点颜色越深,也就是纯度越高。
3)value表示当前节点2种类别的样本有多少,比如下面第一棵树的根节点,value = [37243,2757],表示有37243个好样本,2757坏样本
4)class表示当前那个类别的样本最多,比如下面最右边的一棵树的根节点,class = bad,可以看到当前节点它的坏样本数是最多的。
5)gini:节点的基尼不纯度。当沿着树向下移动时,平均加权的基尼不纯度必须降低。

3、dtreeviz对决策树进行可视化

3.1 dtreeviz 的包详解

https://github.com/parrt/dtreeviz

dtreeviz是我认为非常完美的决策树可视化的包,非常好理解,也非常美观。下面我们看看这个包可视化的结果。


*决策树三种可视化展示形态 - 图5


*决策树三种可视化展示形态 - 图6


*决策树三种可视化展示形态 - 图7

只展示了预测的路径
*决策树三种可视化展示形态 - 图8

有了决策树的可视化,我们就能直接得到每条策略了,当时人为的看,效率还是比较低,我们需要更高效的方式,对数据决策树上的信息进行提取,直接得到规则。

四、决策规则提取

1、决策树的生成的结构探索

要提取出来其中的规则,我们需要探索决策树的存储结构,为了探究sklearn中决策树是如何设计和实现的,以分类决策树为例,首先看下决策树都内置了哪些属性和接口:通过dir属性查看一颗初始的决策树都包含了哪些属性(这里过滤掉了以”_”开头的属性,因为一般是内置私有属性),得到结果如下:

*决策树三种可视化展示形态 - 图9
大致浏览上述结果,属性主要是决策树初始化时的参数,例如ccp_alpha:剪枝系数,class_weight:类的权重,criterion:分裂准则等;还有就是决策树实现的主要函数,例如fit:模型训练,predict:模型预测等等。
本文的重点是探究决策树中是如何保存训练后的”那颗树”,所以我们进一步用鸢尾花数据集对决策树进行训练一下,而后再次调用dir函数,看看增加了哪些属性和接口:
本文的重点是探究决策树中是如何保存训练后的”那颗树”,所以我们进一步用鸢尾花数据集对决策树进行训练一下,而后再次调用dir函数,看看增加了哪些属性和接口:
*决策树三种可视化展示形态 - 图10
通过集合的差集,很明显看出训练前后的决策树主要是增加了6个属性(都是属性,而非函数功能),其中通过属性名字也很容易推断其含义:

  • classes_:分类标签的取值,即y的唯一值集合
  • maxfeatures:最大特征数
  • nclasses:类别数,如2分类或多分类等,即classes_属性中的长度
  • nfeatures_in:输入特征数量,等价于老版sklearn中的nfeatures,现已弃用,并推荐nfeatures_in
  • n_outputs:多输出的个数,即决策树不仅可以用于实现单一的分类问题,还可同时实现多个分类问题,例如给定一组人物特征,用于同时判断其是男/女、胖/瘦和高矮,这是3个分类问题,即3输出(需要区别理解多分类和多输出任务)
  • tree:毫无疑问,这个tree就是今天本文的重点,是在决策树训练之后新增的属性集,其中存储了决策树是如何存储的。

那我们对这个tree属性做进一步探究,首先打印该tree属性发现,这是一个Tree对象,并给出了在sklearn中的文件路径:
*决策树三种可视化展示形态 - 图11
我们可以通过help方法查看Tree类的介绍:
*决策树三种可视化展示形态 - 图12
通过上述doc文档,其中第一句就很明确的对决策树做了如下描述:
Array-based representation of a binary decision tree.
即:基于数组表示的二分类决策树,也就是二叉树!进一步地,在这个二叉树中,数组的第i个元素代表了决策树的第i个节点的信息,节点0表示决策树的根节点。那么每个节点又都蕴含了什么信息呢?我们注意到上述文档中列出了节点的文件名:_tree.pxd,查看其中,很容易发现节点的定义如下:
*决策树三种可视化展示形态 - 图13
虽然是cython的定义语法,但也不难推断其各属性字段的类型和含义,例如:

  • left_child:size类型(无符号整型),代表了当前节点的左子节点的索引
  • right_child:类似于left_child
  • feature:size类型,代表了当前节点用于分裂的特征索引,即在训练集中用第几列特征进行分裂
  • threshold:double类型,代表了当前节点选用相应特征时的分裂阈值,一般是≤该阈值时进入左子节点,否则进入右子节点
  • n_node_samples:size类型,代表了训练时落入到该节点的样本总数。显然,父节点的n_node_samples将等于其左右子节点的n_node_samples之和。

至此,决策树中单个节点的属性定义和实现基本推断完毕,那么整个决策树又是如何将所有节点串起来的呢?我们再次诉诸于训练后决策树的tree_属性,看看它都哪些接口,仍然过滤掉内置私有属性,得到如下结果:

  • 训练后的决策树共包含5个节点,其中3个叶子节点
  • 通过children_left和children_right两个属性,可以知道第0个节点(也就是根节点)的左子节点索引为1,右子节点索引为2,;第1个节点的左右子节点均为-1,意味着该节点即为叶子节点;第2个节点的左右子节点分别为3和4,说明它是一个内部节点,并做了进一步分裂
  • 通过feature和threshold两个属性,可以知道第0个节点(根节点)使用索引为3的特征(对应第4列特征)进行分裂,且其最优分割阈值为0.8;第1个节点因为是叶子节点,所以不再分裂,其对应feature和threshold字段均为-2
  • 通过value属性,可以查看落入每个节点的各类样本数量,由于鸢尾花数据集是一个三分类问题,且该决策树共有5个节点,所以value的取值为一个5×3的二维数组,例如第一行代表落入根节点的样本计数为[50, 50, 50],第二行代表落入左子节点的样本计数为[50, 0, 0],由于已经是纯的了,所以不再继续分裂。
  • 另外,tree中实际上并未直接标出各叶节点所对应的标签值,但完全可通过value属性来得到,即各叶子节点中落入样本最多的类别即为相应标签。甚至说,不仅可知道对应标签,还可通过计算数量之比得到相应的概率!

拿鸢尾花数据集手动验证一下上述猜想,以根节点的分裂特征3和阈值0.8进行分裂,得到落入左子节点的样本计数结果如下,发现确实是分裂后只剩下50个第一类样本,也即样本计数为[50, 0, 0],完全一致。
*决策树三种可视化展示形态 - 图14
另外,通过children_left和children_right两个属性的子节点对应关系,其实我们还可以推断出该二叉树的遍历方式为前序遍历,即按照根-左-右的顺序,对于上述决策树其分裂后对应二叉树示意图如下:
*决策树三种可视化展示形态 - 图15

2、老方法提取决策树规则

通过上面的分析,我们知道了决策树的存储方式,下面就开始规则提取,在网上搜索,基本上只能收到下面的方法:

*决策树三种可视化展示形态 - 图16
’可以看得出来,虽然提取出来了策略,但是并不是很完美,还是需要人为拆解策略,于是我继续研究。

3、新方法提取决策树规则

不仅仅对对二叉树的所有路径进行遍历,还需要进行回溯并组合成变量,根据决策树的输出,构建规则提取函数,需要用到二叉树遍历和回溯算法,本人数据结构不是很好,干了两个晚上,真是编程偷不了懒,出来混迟早要还的。代码比较混乱,大家将就看,还好这个部分对效率没啥要求。如果有更好的代码,不吝赐教。
*决策树三种可视化展示形态 - 图17

*决策树三种可视化展示形态 - 图18
*决策树三种可视化展示形态 - 图19
*决策树三种可视化展示形态 - 图20