CART,又名分类回归树,是在ID3的基础上进行优化的决策树,学习CART记住以下几个关键点:

  • (1)CART既能是分类树,又能是分类树;
  • (2)当CART是分类树时,采用GINI值作为节点分裂的依据;当CART是回归树时,采用样本的最小方差作为节点分裂的依据;
  • (3)CART是一棵二叉树。

接下来将以一个实际的例子对CART进行介绍:

表1 原始数据表

看电视时间 婚姻情况 职业 年龄
3 未婚 学生 12
4 未婚 学生 18
2 已婚 老师 26
5 已婚 上班族 47
2.5 已婚 上班族 36
3.5 未婚 老师 29
4 已婚 学生 21

**
从以下的思路理解CART:

分类树?回归树?

分类树的作用是通过一个对象的特征来预测该对象所属的类别,而回归树的目的是根据一个对象的信息预测该对象的属性,并以数值表示。

CART既能是分类树,又能是决策树,如上表所示,如果我们想预测一个人是否已婚,那么构建的CART将是分类树;如果想预测一个人的年龄,那么构建的将是回归树。

分类树和回归树是怎么做决策的?假设我们构建了两棵决策树分别预测用户是否已婚和实际的年龄,如图1和图2所示:
决策树系列(五)——CART - 图1
图1 预测婚姻情况决策树 图2 预测年龄的决策树

图1表示一棵分类树,其叶子节点的输出结果为一个实际的类别,在这个例子里是婚姻的情况(已婚或者未婚),选择叶子节点中数量占比最大的类别作为输出的类别;

图2是一棵回归树,预测用户的实际年龄,是一个具体的输出值。怎样得到这个输出值?一般情况下选择使用中值、平均值或者众数进行表示,图2使用节点年龄数据的平均值作为输出值。

CART如何选择分裂的属性?

分裂的目的是为了能够让数据变纯,使决策树输出的结果更接近真实值。那么CART是如何评价节点的纯度呢?如果是分类树,CART采用GINI值衡量节点纯度;如果是回归树,采用样本方差衡量节点纯度。节点越不纯,节点分类或者预测的效果就越差。

GINI值的计算公式:

决策树系列(五)——CART - 图2

节点越不纯,GINI值越大。以二分类为例,如果节点的所有数据只有一个类别,则决策树系列(五)——CART - 图3 ,如果两类数量相同,则决策树系列(五)——CART - 图4

回归方差计算公式:

决策树系列(五)——CART - 图5
方差越大,表示该节点的数据越分散,预测的效果就越差。如果一个节点的所有数据都相同,那么方差就为0,此时可以很肯定得认为该节点的输出值;如果节点的数据相差很大,那么输出的值有很大的可能与实际值相差较大。

因此,无论是分类树还是回归树,CART都要选择使子节点的GINI值或者回归方差最小的属性作为分裂的方案。即最小化(分类树):
决策树系列(五)——CART - 图6
或者(回归树):
决策树系列(五)——CART - 图7

CART如何分裂成一棵二叉树?

节点的分裂分为两种情况,连续型的数据和离散型的数据。

CART对连续型属性的处理与C4.5差不多,通过最小化分裂后的GINI值或者样本方差寻找最优分割点,将节点一分为二,在这里不再叙述,详细请看C4.5

对于离散型属性,理论上有多少个离散值就应该分裂成多少个节点。但CART是一棵二叉树,每一次分裂只会产生两个节点,怎么办呢?很简单,只要将其中一个离散值独立作为一个节点,其他的离散值生成另外一个节点即可。这种分裂方案有多少个离散值就有多少种划分的方法,举一个简单的例子:如果某离散属性一个有三个离散值X,Y,Z,则该属性的分裂方法有{X}、{Y,Z},{Y}、{X,Z},{Z}、{X,Y},分别计算每种划分方法的基尼值或者样本方差确定最优的方法。

以属性“职业”为例,一共有三个离散值,“学生”、“老师”、“上班族”。该属性有三种划分的方案,分别为{“学生”}、{“老师”、“上班族”},{“老师”}、{“学生”、“上班族”},{“上班族”}、{“学生”、“老师”},分别计算三种划分方案的子节点GINI值或者样本方差,选择最优的划分方法,如下图所示:

第一种划分方法:{“学生”}、{“老师”、“上班族”}
决策树系列(五)——CART - 图8
预测是否已婚(分类):
决策树系列(五)——CART - 图9
预测年龄(回归):
     决策树系列(五)——CART - 图10


第二种划分方法:{“老师”}、{“学生”、“上班族”}
决策树系列(五)——CART - 图11

预测是否已婚(分类):
决策树系列(五)——CART - 图12

预测年龄(回归):
     决策树系列(五)——CART - 图13

第三种划分方法:{“上班族”}、{“学生”、“老师”}**
决策树系列(五)——CART - 图14
预测是否已婚(分类):
决策树系列(五)——CART - 图15
预测年龄(回归):
    决策树系列(五)——CART - 图16
综上,如果想预测是否已婚,则选择{“上班族”}、{“学生”、“老师”}的划分方法,如果想预测年龄,则选择{“老师”}、{“学生”、“上班族”}的划分方法。

如何剪枝?

CART采用CCP(代价复杂度)剪枝方法。代价复杂度选择节点表面误差率增益值最小的非叶子节点,删除该非叶子节点的左右子节点,若有多个非叶子节点的表面误差率增益值相同小,则选择非叶子节点中子节点数最多的非叶子节点进行剪枝。
可描述如下:

令决策树的非叶子节点为决策树系列(五)——CART - 图17

  • a)计算所有非叶子节点的表面误差率增益值决策树系列(五)——CART - 图18
  • b)选择表面误差率增益值决策树系列(五)——CART - 图19最小的非叶子节点决策树系列(五)——CART - 图20(若多个非叶子节点具有相同小的表面误差率增益值,选择节点数最多的非叶子节点)。
  • c)对决策树系列(五)——CART - 图21进行剪枝

表面误差率增益值的计算公式:

决策树系列(五)——CART - 图22

其中:

  • 决策树系列(五)——CART - 图23表示叶子节点的误差代价,决策树系列(五)——CART - 图24决策树系列(五)——CART - 图25 为节点的错误率, 决策树系列(五)——CART - 图26为节点数据量的占比;
  • 决策树系列(五)——CART - 图27表示子树的误差代价,决策树系列(五)——CART - 图28决策树系列(五)——CART - 图29为子节点i的错误率,决策树系列(五)——CART - 图30 表示节点i的数据节点占比;
  • 决策树系列(五)——CART - 图31表示子树节点个数。

算例:
下图是其中一颗子树,设决策树的总数据量为40。
决策树系列(五)——CART - 图32

该子树的表面误差率增益值可以计算如下:
决策树系列(五)——CART - 图33

求出该子树的表面错误覆盖率为 ,只要求出其他子树的表面误差率增益值就可以对决策树进行剪枝。

程序实际以及源代码


流程图:
决策树系列(五)——CART - 图34**

(1)数据处理

对原始的数据进行数字化处理,并以二维数据的形式存储,每一行表示一条记录,前n-1列表示属性,最后一列表示分类的标签。

如表1的数据可以转化为表2:

表2 初始化后的数据

看电视时间 婚姻情况 职业 年龄
3 未婚 学生 12
4 未婚 学生 18
2 已婚 老师 26
5 已婚 上班族 47
2.5 已婚 上班族 36
3.5 未婚 老师 29
4 已婚 学生 21
  1. <br />其中,对于“婚姻情况”属性,数字{1,2}分别表示{未婚,已婚 };对于“职业”属性{1,2,3, }分别表示{学生、老师、上班族};

代码如下所示:
static double[][] allData; //存储进行训练的数据
static List[] featureValues; //离散属性对应的离散值

featureValues是链表数组,数组的长度为属性的个数,数组的每个元素为该属性的离散值链表。

(2)两个类:节点类和分裂信息

a)节点类Node

  1. <br />该类表示一个节点,属性包括节点选择的分裂属性、节点的输出类、孩子节点、深度等。注意,与ID3中相比,新增了两个属性:leafWrong和leafNode_Count分别表示叶子节点的总分类误差和叶子节点的个数,主要是为了方便剪枝。
  1. class Node
  2. {
  3. /// <summary>
  4. /// 每一个节点的分裂值
  5. /// </summary>
  6. public List<String> features { get; set; }
  7. /// <summary>
  8. /// 分裂属性的类型{离散、连续}
  9. /// </summary>
  10. public String feature_Type { get; set; }
  11. /// <summary>
  12. /// 分裂属性的下标
  13. /// </summary>
  14. public String SplitFeature { get; set; }
  15. //List<int> nums = new List<int>(); //行序号
  16. /// <summary>
  17. /// 每一个类对应的数目
  18. /// </summary>
  19. public double[] ClassCount { get; set; }
  20. //int[] isUsed = new int[0]; //属性的使用情况 1:已用 2:未用
  21. /// <summary>
  22. /// 孩子节点
  23. /// </summary>
  24. public List<Node> childNodes { get; set; }
  25. Node Parent = null;
  26. /// <summary>
  27. /// 该节点占比最大的类别
  28. /// </summary>
  29. public String finalResult { get; set; }
  30. /// <summary>
  31. /// 树的深度
  32. /// </summary>
  33. public int deep { get; set; }
  34. /// <summary>
  35. /// 最大的类下标
  36. /// </summary>
  37. public int result { get; set; }
  38. /// <summary>
  39. /// 子节点误差
  40. /// </summary>
  41. public int leafWrong { get; set; }
  42. /// <summary>
  43. /// 子节点数目
  44. /// </summary>
  45. public int leafNode_Count { get; set; }
  46. /// <summary>
  47. /// 数据量
  48. /// </summary>
  49. public int rowCount { get; set; }
  50. public void setClassCount(double[] count)
  51. {
  52. this.ClassCount = count;
  53. double max = ClassCount[0];
  54. int result = 0;
  55. for (int i = 1; i < ClassCount.Length; i++)
  56. {
  57. if (max < ClassCount[i])
  58. {
  59. max = ClassCount[i];
  60. result = i;
  61. }
  62. }
  63. this.result = result;
  64. }
  65. public double getErrorCount()
  66. {
  67. return rowCount - ClassCount[result];
  68. }
  69. }
  70. 树的节点

b)分裂信息类

该类存储节点进行分裂的信息,包括各个子节点的行坐标、子节点各个类的数目、该节点分裂的属性、属性的类型等。

  1. class SplitInfo
  2. {
  3. /// <summary>
  4. /// 分裂的属性下标
  5. /// </summary>
  6. public int splitIndex { get; set; }
  7. /// <summary>
  8. /// 数据类型
  9. /// </summary>
  10. public int type { get; set; }
  11. /// <summary>
  12. /// 分裂属性的取值
  13. /// </summary>
  14. public List<String> features { get; set; }
  15. /// <summary>
  16. /// 各个节点的行坐标链表
  17. /// </summary>
  18. public List<int>[] temp { get; set; }
  19. /// <summary>
  20. /// 每个节点各类的数目
  21. /// </summary>
  22. public double[][] class_Count { get; set; }
  23. }
  24. 分裂信息

主方法findBestSplit(Node node,List nums,int[] isUsed),该方法对节点进行分裂
其中:

  • node表示即将进行分裂的节点;
  • nums表示节点数据对一个的行坐标列表;
  • isUsed表示到该节点位置所有属性的使用情况;

findBestSplit的这个方法主要有以下几个组成部分:
1)节点分裂停止的判定
节点分裂条件如上文所述,源代码如下:

  1. public static bool ifEnd(Node node, double shang,int[] isUsed)
  2. {
  3. try
  4. {
  5. double[] count = node.ClassCount;
  6. int rowCount = node.rowCount;
  7. int maxResult = 0;
  8. double maxRate = 0;
  9. #region 数达到某一深度
  10. int deep = node.deep;
  11. if (deep >= 10)
  12. {
  13. maxResult = node.result + 1;
  14. node.feature_Type="result";
  15. node.features=new List<String>() { maxResult + ""
  16. };
  17. node.leafWrong=rowCount - Convert.ToInt32(count[maxResult-1]);
  18. node.leafNode_Count=1;
  19. return true;
  20. }
  21. #endregion
  22. #region 纯度(其实跟后面的有点重了,记得要修改)
  23. //maxResult = 1;
  24. //for (int i = 1; i < count.Length; i++)
  25. //{
  26. // if (count[i] / rowCount >= 0.95)
  27. // {
  28. // node.feature_Type="result";
  29. // node.features=new List<String> { "" + (i +
  30. 1) };
  31. // node.leafNode_Count=1;
  32. // node.leafWrong=rowCount - Convert.ToInt32
  33. (count[i]);
  34. // return true;
  35. // }
  36. //}
  37. #endregion
  38. #region 熵为0
  39. if (shang == 0)
  40. {
  41. maxRate = count[0] / rowCount;
  42. maxResult = 1;
  43. for (int i = 1; i < count.Length; i++)
  44. {
  45. if (count[i] / rowCount >= maxRate)
  46. {
  47. maxRate = count[i] / rowCount;
  48. maxResult = i + 1;
  49. }
  50. }
  51. node.feature_Type="result";
  52. node.features=new List<String> { maxResult + ""
  53. };
  54. node.leafWrong=rowCount - Convert.ToInt32(count
  55. [maxResult - 1]);
  56. node.leafNode_Count=1;
  57. return true;
  58. }
  59. #endregion
  60. #region 属性已经分完
  61. //int[] isUsed = node.getUsed();
  62. bool flag = true;
  63. for (int i = 0; i < isUsed.Length - 1; i++)
  64. {
  65. if (isUsed[i] == 0)
  66. {
  67. flag = false;
  68. break;
  69. }
  70. }
  71. if (flag)
  72. {
  73. maxRate = count[0] / rowCount;
  74. maxResult = 1;
  75. for (int i = 1; i < count.Length; i++)
  76. {
  77. if (count[i] / rowCount >= maxRate)
  78. {
  79. maxRate = count[i] / rowCount;
  80. maxResult = i + 1;
  81. }
  82. }
  83. node.feature_Type=("result");
  84. node.features=(new List<String> { "" +
  85. (maxResult) });
  86. node.leafWrong=(rowCount - Convert.ToInt32(count
  87. [maxResult - 1]));
  88. node.leafNode_Count=(1);
  89. return true;
  90. }
  91. #endregion
  92. #region 几点数少于100
  93. if (rowCount < Limit_Node)
  94. {
  95. maxRate = count[0] / rowCount;
  96. maxResult = 1;
  97. for (int i = 1; i < count.Length; i++)
  98. {
  99. if (count[i] / rowCount >= maxRate)
  100. {
  101. maxRate = count[i] / rowCount;
  102. maxResult = i + 1;
  103. }
  104. }
  105. node.feature_Type="result";
  106. node.features=new List<String> { "" + (maxResult)
  107. };
  108. node.leafWrong=rowCount - Convert.ToInt32(count
  109. [maxResult - 1]);
  110. node.leafNode_Count=1;
  111. return true;
  112. }
  113. #endregion
  114. return false;
  115. }
  116. catch (Exception e)
  117. {
  118. return false;
  119. }
  120. }
  121. 停止分裂的条件

2)寻找最优的分裂属性
寻找最优的分裂属性需要计算每一个分裂属性分裂后的GINI值或者样本方差,计算公式上文已给出,其中GINI值的计算代码如下:

  1. public static double getGini(double[] counts, int countAll)
  2. {
  3. double Gini = 1;
  4. for (int i = 0; i < counts.Length; i++)
  5. {
  6. Gini = Gini - Math.Pow(counts[i] / countAll, 2);
  7. }
  8. return Gini;
  9. }
  10. GINI值计算

3)进行分裂,同时对子节点进行迭代处理
**
其实就是递归的过程,对每一个子节点执行findBestSplit方法进行分裂。
findBestSplit源代码:

  1. public static Node findBestSplit(Node node,List<int> nums,int[] isUsed)
  2. {
  3. try
  4. {
  5. //判断是否继续分裂
  6. double totalShang = getGini(node.ClassCount, node.rowCount);
  7. if (ifEnd(node, totalShang, isUsed))
  8. {
  9. return node;
  10. }
  11. #region 变量声明
  12. SplitInfo info = new SplitInfo();
  13. info.initial();
  14. int RowCount = nums.Count; //样本总数
  15. double jubuMax = 1; //局部最大熵
  16. int splitPoint = 0; //分裂的点
  17. double splitValue = 0; //分裂的值
  18. #endregion
  19. for (int i = 0; i < isUsed.Length - 1; i++)
  20. {
  21. if (isUsed[i] == 1)
  22. {
  23. continue;
  24. }
  25. #region 离散变量
  26. if (type[i] == 0)
  27. {
  28. double[][] allCount = new double[allNum[i]][];
  29. for (int j = 0; j < allCount.Length; j++)
  30. {
  31. allCount[j] = new double[classCount];
  32. }
  33. int[] countAllFeature = new int[allNum[i]];
  34. List<int>[] temp = new List<int>[allNum[i]];
  35. double[] allClassCount = node.ClassCount; //所有类别的数量
  36. for (int j = 0; j < temp.Length; j++)
  37. {
  38. temp[j] = new List<int>();
  39. }
  40. for (int j = 0; j < nums.Count; j++)
  41. {
  42. int index = Convert.ToInt32(allData[nums[j]][i]);
  43. temp[index - 1].Add(nums[j]);
  44. countAllFeature[index - 1]++;
  45. allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
  46. }
  47. double allShang = 1;
  48. int choose = 0;
  49. double[][] jubuCount = new double[2][];
  50. for (int k = 0; k < allCount.Length; k++)
  51. {
  52. if (temp[k].Count == 0)
  53. continue;
  54. double JubuShang = 0;
  55. double[][] tempCount = new double[2][];
  56. tempCount[0] = allCount[k];
  57. tempCount[1] = new double[allCount[0].Length];
  58. for (int j = 0; j < tempCount[1].Length; j++)
  59. {
  60. tempCount[1][j] = allClassCount[j] - allCount[k][j];
  61. }
  62. JubuShang = JubuShang + getGini(tempCount[0], countAllFeature[k]) * countAllFeature[k] / RowCount;
  63. int nodecount = RowCount - countAllFeature[k];
  64. JubuShang = JubuShang + getGini(tempCount[1], nodecount) * nodecount / RowCount;
  65. if (JubuShang < allShang)
  66. {
  67. allShang = JubuShang;
  68. jubuCount = tempCount;
  69. choose = k;
  70. }
  71. }
  72. if (allShang < jubuMax)
  73. {
  74. info.type = 0;
  75. jubuMax = allShang;
  76. info.class_Count = jubuCount;
  77. info.temp[0] = temp[choose];
  78. info.temp[1] = new List<int>();
  79. info.features = new List<string>();
  80. info.features.Add((choose + 1) + "");
  81. info.features.Add("");
  82. for (int j = 0; j < temp.Length; j++)
  83. {
  84. if (j == choose)
  85. continue;
  86. for (int k = 0; k < temp[j].Count; k++)
  87. {
  88. info.temp[1].Add(temp[j][k]);
  89. }
  90. if (temp[j].Count != 0)
  91. {
  92. info.features[1] = info.features[1] + (j + 1) + ",";
  93. }
  94. }
  95. info.splitIndex = i;
  96. }
  97. }
  98. #endregion
  99. #region 连续变量
  100. else
  101. {
  102. double[] leftCunt = new double[classCount];
  103. //做节点各个类别的数量
  104. double[] rightCount = new double[classCount];
  105. //右节点各个类别的数量
  106. double[] count1 = new double[classCount];
  107. //子集1的统计量
  108. double[] count2 = new double
  109. [node.ClassCount.Length]; //子集2的统计量
  110. for (int j = 0; j < node.ClassCount.Length;
  111. j++)
  112. {
  113. count2[j] = node.ClassCount[j];
  114. }
  115. int all1 = 0;
  116. //子集1的样本量
  117. int all2 = nums.Count;
  118. //子集2的样本量
  119. double lastValue = 0;
  120. //上一个记录的类别
  121. double currentValue = 0;
  122. //当前类别
  123. double lastPoint = 0;
  124. //上一个点的值
  125. double currentPoint = 0;
  126. //当前点的值
  127. double[] values = new double[nums.Count];
  128. for (int j = 0; j < values.Length; j++)
  129. {
  130. values[j] = allData[nums[j]][i];
  131. }
  132. QSort(values, nums, 0, nums.Count - 1);
  133. double lianxuMax = 1;
  134. //连续型属性的最大熵
  135. #region 寻找最佳的分割点
  136. for (int j = 0; j < nums.Count - 1; j++)
  137. {
  138. currentValue = allData[nums[j]][lieshu -
  139. 1];
  140. currentPoint = (allData[nums[j]][i]);
  141. if (j == 0)
  142. {
  143. lastValue = currentValue;
  144. lastPoint = currentPoint;
  145. }
  146. if (currentValue != lastValue &&
  147. currentPoint != lastPoint)
  148. {
  149. double shang1 = getGini(count1,
  150. all1);
  151. double shang2 = getGini(count2,
  152. all2);
  153. double allShang = shang1 * all1 /
  154. (all1 + all2) + shang2 * all2 / (all1 + all2);
  155. //allShang = (totalShang - allShang);
  156. if (lianxuMax > allShang)
  157. {
  158. lianxuMax = allShang;
  159. for (int k = 0; k <
  160. count1.Length; k++)
  161. {
  162. leftCunt[k] = count1[k];
  163. rightCount[k] = count2[k];
  164. }
  165. splitPoint = j;
  166. splitValue = (currentPoint +
  167. lastPoint) / 2;
  168. }
  169. }
  170. all1++;
  171. count1[Convert.ToInt32(currentValue) -
  172. 1]++;
  173. count2[Convert.ToInt32(currentValue) -
  174. 1]--;
  175. all2--;
  176. lastValue = currentValue;
  177. lastPoint = currentPoint;
  178. }
  179. #endregion
  180. #region 如果超过了局部值,重设
  181. if (lianxuMax < jubuMax)
  182. {
  183. info.type = 1;
  184. info.splitIndex = i;
  185. info.features=new List<string>()
  186. {splitValue+""};
  187. //finalPoint = splitPoint;
  188. jubuMax = lianxuMax;
  189. info.temp[0] = new List<int>();
  190. info.temp[1] = new List<int>();
  191. for (int k = 0; k < splitPoint; k++)
  192. {
  193. info.temp[0].Add(nums[k]);
  194. }
  195. for (int k = splitPoint; k < nums.Count;
  196. k++)
  197. {
  198. info.temp[1].Add(nums[k]);
  199. }
  200. info.class_Count[0] = new double
  201. [leftCunt.Length];
  202. info.class_Count[1] = new double
  203. [leftCunt.Length];
  204. for (int k = 0; k < leftCunt.Length; k++)
  205. {
  206. info.class_Count[0][k] = leftCunt[k];
  207. info.class_Count[1][k] = rightCount
  208. [k];
  209. }
  210. }
  211. #endregion
  212. }
  213. #endregion
  214. }
  215. #region 没有寻找到最佳的分裂点,则设置为叶节点
  216. if (info.splitIndex == -1)
  217. {
  218. double[] finalCount = node.ClassCount;
  219. double max = finalCount[0];
  220. int result = 1;
  221. for (int i = 1; i < finalCount.Length; i++)
  222. {
  223. if (finalCount[i] > max)
  224. {
  225. max = finalCount[i];
  226. result = (i + 1);
  227. }
  228. }
  229. node.feature_Type="result";
  230. node.features=new List<String> { "" + result };
  231. return node;
  232. }
  233. #endregion
  234. #region 分裂
  235. int deep = node.deep;
  236. node.SplitFeature = ("" + info.splitIndex);
  237. List<Node> childNode = new List<Node>();
  238. int[][] used = new int[2][];
  239. used[0] = new int[isUsed.Length];
  240. used[1] = new int[isUsed.Length];
  241. for (int i = 0; i < isUsed.Length; i++)
  242. {
  243. used[0][i] = isUsed[i];
  244. used[1][i] = isUsed[i];
  245. }
  246. if (info.type == 0)
  247. {
  248. used[0][info.splitIndex] = 1;
  249. node.feature_Type = ("离散");
  250. }
  251. else
  252. {
  253. //used[info.splitIndex] = 0;
  254. node.feature_Type = ("连续");
  255. }
  256. List<int>[] rowIndex = info.temp;
  257. List<String> features = info.features;
  258. Node node1 = new Node();
  259. Node node2 = new Node();
  260. node1.setClassCount(info.class_Count[0]);
  261. node2.setClassCount(info.class_Count[1]);
  262. node1.rowCount = info.temp[0].Count;
  263. node2.rowCount = info.temp[1].Count;
  264. node1.deep = deep + 1;
  265. node2.deep = deep + 1;
  266. node1 = findBestSplit(node1, info.temp[0],used[0]);
  267. node2 = findBestSplit(node2, info.temp[1], used[1]);
  268. node.leafNode_Count = (node1.leafNode_Count
  269. +node2.leafNode_Count);
  270. node.leafWrong = (node1.leafWrong+node2.leafWrong);
  271. node.features = (features);
  272. childNode.Add(node1);
  273. childNode.Add(node2);
  274. node.childNodes = childNode;
  275. #endregion
  276. return node;
  277. }
  278. catch (Exception e)
  279. {
  280. Console.WriteLine(e.StackTrace);
  281. return node;
  282. }
  283. }
  284. 节点选择属性和分裂

(4)剪枝
代价复杂度剪枝方法(CCP):

  1. public static void getSeries(Node node)
  2. {
  3. Stack<Node> nodeStack = new Stack<Node>();
  4. if (node != null)
  5. {
  6. nodeStack.Push(node);
  7. }
  8. if (node.feature_Type == "result")
  9. return;
  10. List<Node> childs = node.childNodes;
  11. for (int i = 0; i < childs.Count; i++)
  12. {
  13. getSeries(node);
  14. }
  15. }
  16. CCP代价复杂度剪枝

CART全部核心代码:

  1. /// <summary>
  2. /// 判断是否还需要分裂
  3. /// </summary>
  4. /// <param name="node"></param>
  5. /// <returns></returns>
  6. public static bool ifEnd(Node node, double shang,int[] isUsed)
  7. {
  8. try
  9. {
  10. double[] count = node.ClassCount;
  11. int rowCount = node.rowCount;
  12. int maxResult = 0;
  13. double maxRate = 0;
  14. #region 数达到某一深度
  15. int deep = node.deep;
  16. if (deep >= 10)
  17. {
  18. maxResult = node.result + 1;
  19. node.feature_Type="result";
  20. node.features=new List<String>() { maxResult + ""
  21. };
  22. node.leafWrong=rowCount - Convert.ToInt32(count[maxResult-1]);
  23. node.leafNode_Count=1;
  24. return true;
  25. }
  26. #endregion
  27. #region 纯度(其实跟后面的有点重了,记得要修改)
  28. //maxResult = 1;
  29. //for (int i = 1; i < count.Length; i++)
  30. //{
  31. // if (count[i] / rowCount >= 0.95)
  32. // {
  33. // node.feature_Type="result";
  34. // node.features=new List<String> { "" + (i +
  35. 1) };
  36. // node.leafNode_Count=1;
  37. // node.leafWrong=rowCount - Convert.ToInt32
  38. (count[i]);
  39. // return true;
  40. // }
  41. //}
  42. #endregion
  43. #region 熵为0
  44. if (shang == 0)
  45. {
  46. maxRate = count[0] / rowCount;
  47. maxResult = 1;
  48. for (int i = 1; i < count.Length; i++)
  49. {
  50. if (count[i] / rowCount >= maxRate)
  51. {
  52. maxRate = count[i] / rowCount;
  53. maxResult = i + 1;
  54. }
  55. }
  56. node.feature_Type="result";
  57. node.features=new List<String> { maxResult + ""
  58. };
  59. node.leafWrong=rowCount - Convert.ToInt32(count
  60. [maxResult - 1]);
  61. node.leafNode_Count=1;
  62. return true;
  63. }
  64. #endregion
  65. #region 属性已经分完
  66. //int[] isUsed = node.getUsed();
  67. bool flag = true;
  68. for (int i = 0; i < isUsed.Length - 1; i++)
  69. {
  70. if (isUsed[i] == 0)
  71. {
  72. flag = false;
  73. break;
  74. }
  75. }
  76. if (flag)
  77. {
  78. maxRate = count[0] / rowCount;
  79. maxResult = 1;
  80. for (int i = 1; i < count.Length; i++)
  81. {
  82. if (count[i] / rowCount >= maxRate)
  83. {
  84. maxRate = count[i] / rowCount;
  85. maxResult = i + 1;
  86. }
  87. }
  88. node.feature_Type=("result");
  89. node.features=(new List<String> { "" +
  90. (maxResult) });
  91. node.leafWrong=(rowCount - Convert.ToInt32(count
  92. [maxResult - 1]));
  93. node.leafNode_Count=(1);
  94. return true;
  95. }
  96. #endregion
  97. #region 几点数少于100
  98. if (rowCount < Limit_Node)
  99. {
  100. maxRate = count[0] / rowCount;
  101. maxResult = 1;
  102. for (int i = 1; i < count.Length; i++)
  103. {
  104. if (count[i] / rowCount >= maxRate)
  105. {
  106. maxRate = count[i] / rowCount;
  107. maxResult = i + 1;
  108. }
  109. }
  110. node.feature_Type="result";
  111. node.features=new List<String> { "" + (maxResult)
  112. };
  113. node.leafWrong=rowCount - Convert.ToInt32(count
  114. [maxResult - 1]);
  115. node.leafNode_Count=1;
  116. return true;
  117. }
  118. #endregion
  119. return false;
  120. }
  121. catch (Exception e)
  122. {
  123. return false;
  124. }
  125. }
  126. #region 排序算法
  127. public static void InsertSort(double[] values, List<int> arr,
  128. int StartIndex, int endIndex)
  129. {
  130. for (int i = StartIndex + 1; i <= endIndex; i++)
  131. {
  132. int key = arr[i];
  133. double init = values[i];
  134. int j = i - 1;
  135. while (j >= StartIndex && values[j] > init)
  136. {
  137. arr[j + 1] = arr[j];
  138. values[j + 1] = values[j];
  139. j--;
  140. }
  141. arr[j + 1] = key;
  142. values[j + 1] = init;
  143. }
  144. }
  145. static int SelectPivotMedianOfThree(double[] values, List<int> arr, int low, int high)
  146. {
  147. int mid = low + ((high - low) >> 1);//计算数组中间的元素的下标
  148. //使用三数取中法选择枢轴
  149. if (values[mid] > values[high])//目标: arr[mid] <= arr[high]
  150. {
  151. swap(values, arr, mid, high);
  152. }
  153. if (values[low] > values[high])//目标: arr[low] <= arr[high]
  154. {
  155. swap(values, arr, low, high);
  156. }
  157. if (values[mid] > values[low]) //目标: arr[low] >= arr[mid]
  158. {
  159. swap(values, arr, mid, low);
  160. }
  161. //此时,arr[mid] <= arr[low] <= arr[high]
  162. return low;
  163. //low的位置上保存这三个位置中间的值
  164. //分割时可以直接使用low位置的元素作为枢轴,而不用改变分割函数了
  165. }
  166. static void swap(double[] values, List<int> arr, int t1, int t2)
  167. {
  168. double temp = values[t1];
  169. values[t1] = values[t2];
  170. values[t2] = temp;
  171. int key = arr[t1];
  172. arr[t1] = arr[t2];
  173. arr[t2] = key;
  174. }
  175. static void QSort(double[] values, List<int> arr, int low, int high)
  176. {
  177. int first = low;
  178. int last = high;
  179. int left = low;
  180. int right = high;
  181. int leftLen = 0;
  182. int rightLen = 0;
  183. if (high - low + 1 < 10)
  184. {
  185. InsertSort(values, arr, low, high);
  186. return;
  187. }
  188. //一次分割
  189. int key = SelectPivotMedianOfThree(values, arr, low,
  190. high);//使用三数取中法选择枢轴
  191. double inti = values[key];
  192. int currentKey = arr[key];
  193. while (low < high)
  194. {
  195. while (high > low && values[high] >= inti)
  196. {
  197. if (values[high] == inti)//处理相等元素
  198. {
  199. swap(values, arr, right, high);
  200. right--;
  201. rightLen++;
  202. }
  203. high--;
  204. }
  205. arr[low] = arr[high];
  206. values[low] = values[high];
  207. while (high > low && values[low] <= inti)
  208. {
  209. if (values[low] == inti)
  210. {
  211. swap(values, arr, left, low);
  212. left++;
  213. leftLen++;
  214. }
  215. low++;
  216. }
  217. arr[high] = arr[low];
  218. values[high] = values[low];
  219. }
  220. arr[low] = currentKey;
  221. values[low] = values[key];
  222. //一次快排结束
  223. //把与枢轴key相同的元素移到枢轴最终位置周围
  224. int i = low - 1;
  225. int j = first;
  226. while (j < left && values[i] != inti)
  227. {
  228. swap(values, arr, i, j);
  229. i--;
  230. j++;
  231. }
  232. i = low + 1;
  233. j = last;
  234. while (j > right && values[i] != inti)
  235. {
  236. swap(values, arr, i, j);
  237. i++;
  238. j--;
  239. }
  240. QSort(values, arr, first, low - 1 - leftLen);
  241. QSort(values, arr, low + 1 + rightLen, last);
  242. }
  243. #endregion
  244. /// <summary>
  245. /// 寻找最佳的分裂点
  246. /// </summary>
  247. /// <param name="num"></param>
  248. /// <param name="node"></param>
  249. public static Node findBestSplit(Node node,List<int> nums,int[] isUsed)
  250. {
  251. try
  252. {
  253. //判断是否继续分裂
  254. double totalShang = getGini(node.ClassCount, node.rowCount);
  255. if (ifEnd(node, totalShang, isUsed))
  256. {
  257. return node;
  258. }
  259. #region 变量声明
  260. SplitInfo info = new SplitInfo();
  261. info.initial();
  262. int RowCount = nums.Count; //样本总数
  263. double jubuMax = 1; //局部最大熵
  264. int splitPoint = 0; //分裂的点
  265. double splitValue = 0; //分裂的值
  266. #endregion
  267. for (int i = 0; i < isUsed.Length - 1; i++)
  268. {
  269. if (isUsed[i] == 1)
  270. {
  271. continue;
  272. }
  273. #region 离散变量
  274. if (type[i] == 0)
  275. {
  276. double[][] allCount = new double[allNum[i]][];
  277. for (int j = 0; j < allCount.Length; j++)
  278. {
  279. allCount[j] = new double[classCount];
  280. }
  281. int[] countAllFeature = new int[allNum[i]];
  282. List<int>[] temp = new List<int>[allNum[i]];
  283. double[] allClassCount = node.ClassCount; //所有类别的数量
  284. for (int j = 0; j < temp.Length; j++)
  285. {
  286. temp[j] = new List<int>();
  287. }
  288. for (int j = 0; j < nums.Count; j++)
  289. {
  290. int index = Convert.ToInt32(allData[nums[j]][i]);
  291. temp[index - 1].Add(nums[j]);
  292. countAllFeature[index - 1]++;
  293. allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
  294. }
  295. double allShang = 1;
  296. int choose = 0;
  297. double[][] jubuCount = new double[2][];
  298. for (int k = 0; k < allCount.Length; k++)
  299. {
  300. if (temp[k].Count == 0)
  301. continue;
  302. double JubuShang = 0;
  303. double[][] tempCount = new double[2][];
  304. tempCount[0] = allCount[k];
  305. tempCount[1] = new double[allCount[0].Length];
  306. for (int j = 0; j < tempCount[1].Length; j++)
  307. {
  308. tempCount[1][j] = allClassCount[j] - allCount[k][j];
  309. }
  310. JubuShang = JubuShang + getGini(tempCount[0], countAllFeature[k]) * countAllFeature[k] / RowCount;
  311. int nodecount = RowCount - countAllFeature[k];
  312. JubuShang = JubuShang + getGini(tempCount[1], nodecount) * nodecount / RowCount;
  313. if (JubuShang < allShang)
  314. {
  315. allShang = JubuShang;
  316. jubuCount = tempCount;
  317. choose = k;
  318. }
  319. }
  320. if (allShang < jubuMax)
  321. {
  322. info.type = 0;
  323. jubuMax = allShang;
  324. info.class_Count = jubuCount;
  325. info.temp[0] = temp[choose];
  326. info.temp[1] = new List<int>();
  327. info.features = new List<string>();
  328. info.features.Add((choose + 1) + "");
  329. info.features.Add("");
  330. for (int j = 0; j < temp.Length; j++)
  331. {
  332. if (j == choose)
  333. continue;
  334. for (int k = 0; k < temp[j].Count; k++)
  335. {
  336. info.temp[1].Add(temp[j][k]);
  337. }
  338. if (temp[j].Count != 0)
  339. {
  340. info.features[1] = info.features[1] + (j + 1) + ",";
  341. }
  342. }
  343. info.splitIndex = i;
  344. }
  345. }
  346. #endregion
  347. #region 连续变量
  348. else
  349. {
  350. double[] leftCunt = new double[classCount];
  351. //做节点各个类别的数量
  352. double[] rightCount = new double[classCount];
  353. //右节点各个类别的数量
  354. double[] count1 = new double[classCount];
  355. //子集1的统计量
  356. double[] count2 = new double
  357. [node.ClassCount.Length]; //子集2的统计量
  358. for (int j = 0; j < node.ClassCount.Length;
  359. j++)
  360. {
  361. count2[j] = node.ClassCount[j];
  362. }
  363. int all1 = 0;
  364. //子集1的样本量
  365. int all2 = nums.Count;
  366. //子集2的样本量
  367. double lastValue = 0;
  368. //上一个记录的类别
  369. double currentValue = 0;
  370. //当前类别
  371. double lastPoint = 0;
  372. //上一个点的值
  373. double currentPoint = 0;
  374. //当前点的值
  375. double[] values = new double[nums.Count];
  376. for (int j = 0; j < values.Length; j++)
  377. {
  378. values[j] = allData[nums[j]][i];
  379. }
  380. QSort(values, nums, 0, nums.Count - 1);
  381. double lianxuMax = 1;
  382. //连续型属性的最大熵
  383. #region 寻找最佳的分割点
  384. for (int j = 0; j < nums.Count - 1; j++)
  385. {
  386. currentValue = allData[nums[j]][lieshu -
  387. 1];
  388. currentPoint = (allData[nums[j]][i]);
  389. if (j == 0)
  390. {
  391. lastValue = currentValue;
  392. lastPoint = currentPoint;
  393. }
  394. if (currentValue != lastValue &&
  395. currentPoint != lastPoint)
  396. {
  397. double shang1 = getGini(count1,
  398. all1);
  399. double shang2 = getGini(count2,
  400. all2);
  401. double allShang = shang1 * all1 /
  402. (all1 + all2) + shang2 * all2 / (all1 + all2);
  403. //allShang = (totalShang - allShang);
  404. if (lianxuMax > allShang)
  405. {
  406. lianxuMax = allShang;
  407. for (int k = 0; k <
  408. count1.Length; k++)
  409. {
  410. leftCunt[k] = count1[k];
  411. rightCount[k] = count2[k];
  412. }
  413. splitPoint = j;
  414. splitValue = (currentPoint +
  415. lastPoint) / 2;
  416. }
  417. }
  418. all1++;
  419. count1[Convert.ToInt32(currentValue) -
  420. 1]++;
  421. count2[Convert.ToInt32(currentValue) -
  422. 1]--;
  423. all2--;
  424. lastValue = currentValue;
  425. lastPoint = currentPoint;
  426. }
  427. #endregion
  428. #region 如果超过了局部值,重设
  429. if (lianxuMax < jubuMax)
  430. {
  431. info.type = 1;
  432. info.splitIndex = i;
  433. info.features=new List<string>()
  434. {splitValue+""};
  435. //finalPoint = splitPoint;
  436. jubuMax = lianxuMax;
  437. info.temp[0] = new List<int>();
  438. info.temp[1] = new List<int>();
  439. for (int k = 0; k < splitPoint; k++)
  440. {
  441. info.temp[0].Add(nums[k]);
  442. }
  443. for (int k = splitPoint; k < nums.Count;
  444. k++)
  445. {
  446. info.temp[1].Add(nums[k]);
  447. }
  448. info.class_Count[0] = new double
  449. [leftCunt.Length];
  450. info.class_Count[1] = new double
  451. [leftCunt.Length];
  452. for (int k = 0; k < leftCunt.Length; k++)
  453. {
  454. info.class_Count[0][k] = leftCunt[k];
  455. info.class_Count[1][k] = rightCount
  456. [k];
  457. }
  458. }
  459. #endregion
  460. }
  461. #endregion
  462. }
  463. #region 没有寻找到最佳的分裂点,则设置为叶节点
  464. if (info.splitIndex == -1)
  465. {
  466. double[] finalCount = node.ClassCount;
  467. double max = finalCount[0];
  468. int result = 1;
  469. for (int i = 1; i < finalCount.Length; i++)
  470. {
  471. if (finalCount[i] > max)
  472. {
  473. max = finalCount[i];
  474. result = (i + 1);
  475. }
  476. }
  477. node.feature_Type="result";
  478. node.features=new List<String> { "" + result };
  479. return node;
  480. }
  481. #endregion
  482. #region 分裂
  483. int deep = node.deep;
  484. node.SplitFeature = ("" + info.splitIndex);
  485. List<Node> childNode = new List<Node>();
  486. int[][] used = new int[2][];
  487. used[0] = new int[isUsed.Length];
  488. used[1] = new int[isUsed.Length];
  489. for (int i = 0; i < isUsed.Length; i++)
  490. {
  491. used[0][i] = isUsed[i];
  492. used[1][i] = isUsed[i];
  493. }
  494. if (info.type == 0)
  495. {
  496. used[0][info.splitIndex] = 1;
  497. node.feature_Type = ("离散");
  498. }
  499. else
  500. {
  501. //used[info.splitIndex] = 0;
  502. node.feature_Type = ("连续");
  503. }
  504. List<int>[] rowIndex = info.temp;
  505. List<String> features = info.features;
  506. Node node1 = new Node();
  507. Node node2 = new Node();
  508. node1.setClassCount(info.class_Count[0]);
  509. node2.setClassCount(info.class_Count[1]);
  510. node1.rowCount = info.temp[0].Count;
  511. node2.rowCount = info.temp[1].Count;
  512. node1.deep = deep + 1;
  513. node2.deep = deep + 1;
  514. node1 = findBestSplit(node1, info.temp[0],used[0]);
  515. node2 = findBestSplit(node2, info.temp[1], used[1]);
  516. node.leafNode_Count = (node1.leafNode_Count
  517. +node2.leafNode_Count);
  518. node.leafWrong = (node1.leafWrong+node2.leafWrong);
  519. node.features = (features);
  520. childNode.Add(node1);
  521. childNode.Add(node2);
  522. node.childNodes = childNode;
  523. #endregion
  524. return node;
  525. }
  526. catch (Exception e)
  527. {
  528. Console.WriteLine(e.StackTrace);
  529. return node;
  530. }
  531. }
  532. /// <summary>
  533. /// GINI值
  534. /// </summary>
  535. /// <param name="counts"></param>
  536. /// <param name="countAll"></param>
  537. /// <returns></returns>
  538. public static double getGini(double[] counts, int countAll)
  539. {
  540. double Gini = 1;
  541. for (int i = 0; i < counts.Length; i++)
  542. {
  543. Gini = Gini - Math.Pow(counts[i] / countAll, 2);
  544. }
  545. return Gini;
  546. }
  547. #region CCP剪枝
  548. public static void getSeries(Node node)
  549. {
  550. Stack<Node> nodeStack = new Stack<Node>();
  551. if (node != null)
  552. {
  553. nodeStack.Push(node);
  554. }
  555. if (node.feature_Type == "result")
  556. return;
  557. List<Node> childs = node.childNodes;
  558. for (int i = 0; i < childs.Count; i++)
  559. {
  560. getSeries(node);
  561. }
  562. }
  563. /// <summary>
  564. /// 遍历剪枝
  565. /// </summary>
  566. /// <param name="node"></param>
  567. public static Node getNode1(Node node, Node nodeCut)
  568. {
  569. //List<Node> childNodes = node.getChild();
  570. //double min = 100000;
  571. ////Node nodeCut = new Node();
  572. //double temp = 0;
  573. //for (int i = 0; i < childNodes.Count; i++)
  574. //{
  575. // if (childNodes[i].getType() != "result")
  576. // {
  577. // //if (!cutTree(childNodes[i]))
  578. // temp = min;
  579. // min = cutTree(childNodes[i], min);
  580. // if (min < temp)
  581. // nodeCut = childNodes[i];
  582. // getNode1(childNodes[i], nodeCut);
  583. // }
  584. //}
  585. //node.setChildNode(childNodes);
  586. return null;
  587. }
  588. /// <summary>
  589. /// 对每一个节点剪枝
  590. /// </summary>
  591. public static double cutTree(Node node, double minA)
  592. {
  593. int rowCount = node.rowCount;
  594. double leaf = node.getErrorCount();
  595. double[] values = getError1(node, 0, 0);
  596. double treeWrong = values[0];
  597. double son = values[1];
  598. double rate = (leaf - treeWrong) / (son - 1);
  599. if (minA > rate)
  600. minA = rate;
  601. //double var = Math.Sqrt(treeWrong * (1 - treeWrong /
  602. rowCount));
  603. //double panbie = treeWrong + var - leaf;
  604. //if (panbie > 0)
  605. //{
  606. // node.setFeatureType("result");
  607. // node.setChildNode(null);
  608. // int result = (node.getResult() + 1);
  609. // node.setFeatures(new List<String>() { "" + result
  610. });
  611. // //return true;
  612. //}
  613. return minA;
  614. }
  615. /// <summary>
  616. /// 获得子树的错误个数
  617. /// </summary>
  618. /// <param name="node"></param>
  619. /// <returns></returns>
  620. public static double[] getError1(Node node, double treeError,
  621. double son)
  622. {
  623. if (node.feature_Type == "result")
  624. {
  625. double error = node.getErrorCount();
  626. son++;
  627. return new double[] { treeError + error, son };
  628. }
  629. List<Node> childNode = node.childNodes;
  630. for (int i = 0; i < childNode.Count; i++)
  631. {
  632. double[] values = getError1(childNode[i], treeError,
  633. son);
  634. treeError = values[0];
  635. son = values[1];
  636. }
  637. return new double[] { treeError, son };
  638. }
  639. #endregion
  640. CART核心代码

总结

**

  • (1)CART是一棵二叉树,每一次分裂会产生两个子节点,对于连续性的数据,直接采用与C4.5相似的处理方法,对于离散型数据,选择最优的两种离散值组合方法。
  • (2)CART既能是分类数,又能是二叉树。如果是分类树,将选择能够最小化分裂后节点GINI值的分裂属性;如果是回归树,选择能够最小化两个节点样本方差的分裂属性。
  • (3)CART跟C4.5一样,需要进行剪枝,采用CCP(代价复杂度的剪枝方法)。