什么是决策树

想象一样生活中的一个场景,妈妈给女儿介绍男朋友
女儿:长得帅不帅?
妈妈:挻帅的。
女儿:有没有房子?
妈妈:在老家有一个。
女儿:收入怎么样?
妈妈:还不错,年薪百万。
女儿:做什么工作的?
妈妈:IT,互联网公司做数据挖掘的。
女儿:好,我见见。
现实生活中,我们会遇见各种选择,都是基于经验来做判断,如果把判断背后的逻辑整理成一个结构图,你会发现实际上就是一个树状图,这就是决策树

可能上面的说法还不是很具体,决策树在sklearn中简单使用如下:

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from sklearn import datasets
  4. # 使用鸢尾花数据
  5. iris = datasets.load_iris()
  6. X = iris.data[:, 2:] # 使用数据后两个特征
  7. y = iris.target
  8. # 画图
  9. plt.scatter(X[y==0, 0], X[y==0, 1])
  10. plt.scatter(X[y==1, 0], X[y==1, 1])
  11. plt.scatter(X[y==2, 0], X[y==2, 1])
  12. plt.show()

clf.png
从散点图可以很清楚的看到鸢尾花三个分类的数据特征非常明显。

  1. # 使用sklearn的决策树进行分类学习
  2. from sklearn.tree import DecisionTreeClassifier
  3. dt_clf = DecisionTreeClassifier(max_depth=2, criterion="entropy")
  4. dt_clf.fit(X, y)
  5. # 画出决策分类
  6. # 决策边界
  7. def plot_decision_boundary(model, axis):
  8. x0, x1 = np.meshgrid(
  9. np.linspace(axis[0], axis[1], int((axis[1]-axis[0])*100)).reshape(-1, 1),
  10. np.linspace(axis[2], axis[3], int((axis[3]-axis[2])*100)).reshape(-1, 1),
  11. )
  12. X_new = np.c_[x0.ravel(), x1.ravel()]
  13. y_predict = model.predict(X_new)
  14. zz = y_predict.reshape(x0.shape)
  15. from matplotlib.colors import ListedColormap
  16. custom_cmap = ListedColormap(['#EF9A9A','#FFF59D','#90CAF9'])
  17. plt.contourf(x0, x1, zz, linewidth=5, cmap=custom_cmap)
  18. plot_decision_boundary(dt_clf, axis=[0.5, 7.5, 0, 3])
  19. plt.scatter(X[y==0, 0], X[y==0, 1])
  20. plt.scatter(X[y==1, 0], X[y==1, 1])
  21. plt.scatter(X[y==2, 0], X[y==2, 1])
  22. plt.show()

boundary.png决策树.png
如图,sklearn的决策方式可以用树模型可以表示
小结:

  • 非参数学习算法
  • 可以解决分类问题,天然可以解决多分类问题
  • 也可解决回归问题
  • 非常好的可解释性

    构造决策树要解决的三个问题

    每个节点在哪个维度做划分
    某个维度在哪个值上做划分
    什么时候停止,并得到目标状态

    决策树构造过程

    信息熵

    熵在信息论中代表随机变量不确定度的度量。
    通俗的讲:
    熵越大,数据的不确定性越高
    熵越小,数据的不确定性低