如上一篇文章所述,ID3方法主要有几个缺点:一是采用信息增益进行数据分裂,准确性不如信息增益率;二是不能对连续数据进行处理,只能通过连续数据离散化进行处理;三是没有采用剪枝的策略,决策树的结构可能会过于复杂,可能会出现过拟合的情况。

C4.5在ID3的基础上对上述三个方面进行了相应的改进:
a) C4.5对节点进行分裂时采用信息增益率作为分裂的依据;
b) 能够对连续数据进行处理;
c) C4.5采用剪枝的策略,对完全生长的决策树进行剪枝处理,一定程度上降低过拟合的影响。
**

1.采用信息增益率作为分裂的依据

信息增益率的计算公式为:
决策树系列(四)——C4.5 - 图1

其中决策树系列(四)——C4.5 - 图2表示信息增益,决策树系列(四)——C4.5 - 图3表示分裂子节点数据量的信息增益,计算公式为:
决策树系列(四)——C4.5 - 图4

其中m表示节点的数量,Ni表示第i个节点的数据量,N表示父亲节点的数据量,说白了,决策树系列(四)——C4.5 - 图5其实是分裂节点的熵。

信息增益率越大,说明分裂的效果越好。
以一个实际的例子说明C4.5如何通过信息增益率选择分裂的属性:

表1 原始数据表

当天天气 温度 湿度 日期 逛街
25 50 工作日
21 48 工作日
18 70 周末
28 41 周末
8 65 工作日
18 43 工作日
24 56 周末
18 76 周末
31 61 周末
6 43 周末
15 55 工作日
4 58 工作日

以当天天气为例:

一共有三个属性值,晴、阴、雨,一共分裂成三个子节点。
决策树系列(四)——C4.5 - 图6

根据上述公式,可以计算信息增益率如下:
决策树系列(四)——C4.5 - 图7

所以使用天气属性进行分裂可以得到信息增益率0.44。
**

2.对连续型属性进行处理

C4.5处理离散型属性的方式与ID3一致,新增对连续型属性的处理。处理方式是先根据连续型属性进行排序,然后采用一刀切的方式将数据砍成两半。

那么如何选择切割点呢?很简单,直接计算每一个切割点切割后的信息增益,然后选择使分裂效果最优的切割点。

以温度为例:
决策树系列(四)——C4.5 - 图8

从上图可以看出,理论上来讲,N条数据就有N-1个切割点,为了选取最优的切割垫,要计算按每一次切割的信息增益,计算量是比较大的,那么有没有简化的方法呢?有,注意到,其实有些切割点是很明显可以排除的。比如说上图右侧的第2条和第3条记录,两者的类标签(逛街)都是“是”,如果从这里切割的话,就将两个本来相同的类分开了,肯定不会比将他们归为一类的切分方法好,因此,可以通过去除前后两个类标签相同的切割点以简化计算的复杂度,如下图所示:

决策树系列(四)——C4.5 - 图9

从图中可以看出,最终切割点的数目从原来的11个减少到现在的6个,降低了计算的复杂度。

确定了分割点之后,接下来就是选择最优的分割点了,注意,对连续型属性是采用信息增益进行内部择优的,因为如果使用信息增益率进行分裂会出现倾向于选择分割前后两个节点数据量相差最大的分割点,为了避免这种情况,选择信息增益选择分割点。选择了最优的分割点之后,再计算信息增益率跟其他的属性进行比较,确定最优的分裂属性。

3. 剪枝

决策树只已经提到,剪枝是在完全生长的决策树的基础上,对生长后分类效果不佳的子树进行修剪,减小决策树的复杂度,降低过拟合的影响。

C4.5采用悲观剪枝方法(PEP)。悲观剪枝认为如果决策树的精度在剪枝前后没有影响的话,则进行剪枝。怎样才算是没有影响?如果剪枝后的误差小于剪枝前经度的上限,则说明剪枝后的效果与更佳,此时需要子树进行剪枝操作。

进行剪枝必须满足的条件:
决策树系列(四)——C4.5 - 图10

其中:
决策树系列(四)——C4.5 - 图11表示子树的误差;
决策树系列(四)——C4.5 - 图12 表示叶子节点的误差;

令子树误差的经度满足二项分布,根据二项分布的性质,决策树系列(四)——C4.5 - 图13决策树系列(四)——C4.5 - 图14,其中决策树系列(四)——C4.5 - 图15,N为子树的数据量;同样,叶子节点的误差决策树系列(四)——C4.5 - 图16

上述公式中,0.5表示修正因子。由于对父节点进行分裂总会得到比父节点分类结果更好的效果,因此,因此从理论上来说,父节点的误差总是不小于孩子节点的误差,因此需要进行修正,给每一个节点都加上0.5的修正因此,在计算误差的时候,子节点由于加上了修正的因子,就无法保证总误差总是低于父节点。

算例:
决策树系列(四)——C4.5 - 图17
决策树系列(四)——C4.5 - 图18

由于决策树系列(四)——C4.5 - 图19,所以应该进行剪枝。

程序设计及源代码(C#版)

程序的设计过程

(1)数据格式

对原始的数据进行数字化处理,并以二维数据的形式存储,每一行表示一条记录,前n-1列表示属性,最后一列表示分类的标签。

如表1的数据可以转化为表2:

表2 初始化后的数据**

当天天气 温度 湿度 季节 明天天气
1 25 50 1 1
2 21 48 1 2
2 18 70 1 3
1 28 41 2 1
3 8 65 3 2
1 18 43 2 1
2 24 56 4 1
3 18 76 4 2
3 31 61 2 1
2 6 43 3 3
1 15 55 4 2
3 4 58 3 3

其中,对于“当天天气”属性,数字{1,2,3}分别表示{晴,阴,雨};对于“季节”属性{1,2,3,4}分别表示{春天、夏天、冬天、秋天};对于类标签“明天天气”,数字{1,2,3}分别表示{晴、阴、雨}。
代码如下所示:

static double[][] allData; //存储进行训练的数据
static List[] featureValues; //离散属性对应的离散值
featureValues是链表数组,数组的长度为属性的个数,数组的每个元素为该属性的离散值链表。

(2)两个类:节点类和分裂信息

a)节点类Node

该类表示一个节点,属性包括节点选择的分裂属性、节点的输出类、孩子节点、深度等。注意,与ID3中相比,新增了两个属性:leafWrong和leafNode_Count分别表示叶子节点的总分类误差和叶子节点的个数,主要是为了方便剪枝。

  1. class Node
  2. {
  3. /// <summary>
  4. /// 各个子节点对应的取值
  5. /// </summary>
  6. //public List<String> features;
  7. public List<String> features{get;set;}
  8. /// <summary>
  9. /// 分裂属性的数据类型(1:连续 0:离散)
  10. /// </summary>
  11. public String feature_Type {get;set;}
  12. /// <summary>
  13. /// 分裂属性列的下标
  14. /// </summary>
  15. public String SplitFeature {get;set;}
  16. /// <summary>
  17. /// 各类别的数量统计
  18. /// </summary>
  19. public double[] ClassCount {get;set;}
  20. /// <summary>
  21. /// 数据量
  22. /// </summary>
  23. public int rowCount { get; set; }
  24. /// <summary>
  25. /// 各个子节点
  26. /// </summary>
  27. public List<Node> childNodes {get;set;}
  28. /// <summary>
  29. /// 父亲节点
  30. /// </summary>
  31. public Node Parent {get;set;}
  32. /// <summary>
  33. /// 该节点占比最大的类别
  34. /// </summary>
  35. public String finalResult {get;set;}
  36. /// <summary>
  37. /// 数的深度
  38. /// </summary>
  39. public int deep {get;set;}
  40. /// <summary>
  41. /// 节点占比最大类的标号
  42. /// </summary>
  43. public int result {get;set;}
  44. /// <summary>
  45. /// 子节点的错误数
  46. /// </summary>
  47. public int leafWrong {get;set;}
  48. /// <summary>
  49. /// 子节点的数目
  50. /// </summary>
  51. public int leafNode_Count {get;set;}
  52. public double getErrorCount()
  53. {
  54. return rowCount - ClassCount[result];
  55. }
  56. #region
  57. public void setClassCount(double[] count)
  58. {
  59. this.ClassCount = count;
  60. double max = ClassCount[0];
  61. int result = 0;
  62. for (int i = 1; i < ClassCount.Length; i++)
  63. {
  64. if (max < ClassCount[i])
  65. {
  66. max = ClassCount[i];
  67. result = i;
  68. }
  69. }
  70. this.result = result;
  71. }
  72. #endregion
  73. }

b)分裂信息类

该类存储节点进行分裂的信息,包括各个子节点的行坐标、子节点各个类的数目、该节点分裂的属性、属性的类型等。

  1. class SplitInfo
  2. {
  3. /// <summary>
  4. /// 分裂的属性下标
  5. /// </summary>
  6. public int splitIndex { get; set; }
  7. /// <summary>
  8. /// 数据类型
  9. /// </summary>
  10. public int type { get; set; }
  11. /// <summary>
  12. /// 分裂属性的取值
  13. /// </summary>
  14. public List<String> features { get; set; }
  15. /// <summary>
  16. /// 各个节点的行坐标链表
  17. /// </summary>
  18. public List<int>[] temp { get; set; }
  19. /// <summary>
  20. /// 每个节点各类的数目
  21. /// </summary>
  22. public double[][] class_Count { get; set; }
  23. }

主方法findBestSplit(Node node,List nums,int[] isUsed),该方法对节点进行分裂

其中:

  • node表示即将进行分裂的节点;
  • nums表示节点数据的行坐标列表;
  • isUsed表示到该节点位置所有属性的使用情况;

findBestSplit的这个方法主要有以下几个组成部分:

1)节点分裂停止的判定
节点分裂条件如上文所述,源代码如下:

  1. public static bool ifEnd(Node node, double entropy,int[] isUsed)
  2. {
  3. try
  4. {
  5. double[] count = node.ClassCount;
  6. int rowCount = node.rowCount;
  7. int maxResult = 0;
  8. #region 数达到某一深度
  9. int deep = node.deep;
  10. if (deep >= maxDeep)
  11. {
  12. maxResult = node.result + 1;
  13. node.feature_Type=("result");
  14. node.features=(new List<String>() { maxResult + "" });
  15. node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
  16. node.leafNode_Count = 1;
  17. return true;
  18. }
  19. #endregion
  20. #region 纯度(其实跟后面的有点重了,记得要修改)
  21. //maxResult = 1;
  22. //for (int i = 1; i < count.Length; i++)
  23. //{
  24. // if (count[i] / rowCount >= 0.95)
  25. // {
  26. // node.feature_Type=("result");
  27. // node.features=(new List<String> { "" + (i + 1) });
  28. // node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
  29. // node.leafNode_Count = 1;
  30. // return true;
  31. // }
  32. //}
  33. #endregion
  34. #region 熵为0
  35. if (entropy == 0)
  36. {
  37. maxResult = node.result+1;
  38. node.feature_Type=("result");
  39. node.features=(new List<String> { maxResult + "" });
  40. node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
  41. node.leafNode_Count = 1;
  42. return true;
  43. }
  44. #endregion
  45. #region 属性已经分完
  46. bool flag = true;
  47. for (int i = 0; i < isUsed.Length - 1; i++)
  48. {
  49. if (isUsed[i] == 0)
  50. {
  51. flag = false;
  52. break;
  53. }
  54. }
  55. if (flag)
  56. {
  57. maxResult = node.result+1;
  58. node.feature_Type=("result");
  59. node.features=(new List<String> { "" + (maxResult) });
  60. node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
  61. node.leafNode_Count = 1;
  62. return true;
  63. }
  64. #endregion
  65. #region 数据量少于100
  66. if (rowCount < Limit_Node)
  67. {
  68. maxResult = node.result+1;
  69. node.feature_Type=("result");
  70. node.features=(new List<String> { "" + (maxResult) });
  71. node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
  72. node.leafNode_Count = 1;
  73. return true;
  74. }
  75. #endregion
  76. return false;
  77. }
  78. catch (Exception e)
  79. {
  80. return false;
  81. }
  82. }

2)寻找最优的分裂属性
寻找最优的分裂属性需要计算每一个分裂属性分裂后的信息增益率,计算公式上文已给出,其中熵的计算代码如下:

  1. public static double CalEntropy(double[] counts, int countAll)
  2. {
  3. try
  4. {
  5. double allShang = 0;
  6. for (int i = 0; i < counts.Length; i++)
  7. {
  8. if (counts[i] == 0)
  9. {
  10. continue;
  11. }
  12. double rate = counts[i] / countAll;
  13. allShang = allShang + rate * Math.Log(rate, 2);
  14. }
  15. return allShang;
  16. }
  17. catch (Exception e)
  18. {
  19. return 0;
  20. }
  21. }

3)进行分裂,同时对子节点进行迭代处理
其实就是递归的工程,对每一个子节点执行findBestSplit方法进行分裂。
findBestSplit源代码:

  1. public static Node findBestSplit(Node node, List<int> nums, int[] isUsed)
  2. {
  3. try
  4. {
  5. //判断是否继续分裂
  6. double totalShang = CalEntropy(node.ClassCount, node.rowCount);
  7. if (ifEnd(node, totalShang,isUsed))
  8. {
  9. return node;
  10. }
  11. #region 变量声明
  12. SplitInfo info = new SplitInfo();
  13. int RowCount = nums.Count; //样本总数
  14. double jubuMax = 0; //局部最大熵
  15. #endregion
  16. for (int i = 0; i < isUsed.Length - 1; i++)
  17. {
  18. if (isUsed[i] == 1)
  19. {
  20. continue;
  21. }
  22. #region 离散变量
  23. if (type[i] == 0)
  24. {
  25. int[] allFeatureCount = new int[0]; //所有类别的数量
  26. double[][] allCount = new double[allNum[i]][];
  27. for (int j = 0; j < allCount.Length; j++)
  28. {
  29. allCount[j] = new double[classCount];
  30. }
  31. int[] countAllFeature = new int[allNum[i]];
  32. List<int>[] temp = new List<int>[allNum[i]];
  33. for (int j = 0; j < temp.Length; j++)
  34. {
  35. temp[j] = new List<int>();
  36. }
  37. for (int j = 0; j < nums.Count; j++)
  38. {
  39. int index = Convert.ToInt32(allData[nums[j]][i]);
  40. temp[index - 1].Add(nums[j]);
  41. countAllFeature[index - 1]++;
  42. allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
  43. }
  44. double allShang = 0;
  45. double chushu = 0;
  46. for (int j = 0; j < allCount.Length; j++)
  47. {
  48. allShang = allShang + CalEntropy(allCount[j], countAllFeature[j]) * countAllFeature[j] / RowCount;
  49. if (countAllFeature[j] > 0)
  50. {
  51. double rate = countAllFeature[j] / Convert.ToDouble(RowCount);
  52. chushu = chushu + rate * Math.Log(rate, 2);
  53. }
  54. }
  55. allShang = (-totalShang + allShang);
  56. if (allShang > jubuMax)
  57. {
  58. info.features = new List<string>();
  59. info.type = 0;
  60. info.temp = temp;
  61. info.splitIndex = i;
  62. info.class_Count = allCount;
  63. jubuMax = allShang;
  64. allFeatureCount = countAllFeature;
  65. }
  66. }
  67. #endregion
  68. #region 连续变量
  69. else
  70. {
  71. double[] leftCount = new double[classCount]; //做节点各个类别的数量
  72. double[] rightCount = new double[classCount]; //右节点各个类别的数量
  73. double[] count1 = new double[classCount]; //子集1的统计量
  74. //double[] count2 = new double[node.getCount().Length]; //子集2的统计量
  75. double[] count2 = new double[node.ClassCount.Length]; //子集2的统计量
  76. for (int j = 0; j < node.ClassCount.Length; j++)
  77. {
  78. count2[j] = node.ClassCount[j];
  79. }
  80. int all1 = 0; //子集1的样本量
  81. int all2 = nums.Count; //子集2的样本量
  82. double lastValue = 0; //上一个记录的类别
  83. double currentValue = 0; //当前类别
  84. double lastPoint = 0; //上一个点的值
  85. double currentPoint = 0; //当前点的值
  86. int splitPoint = 0;
  87. double splitValue = 0;
  88. double[] values = new double[nums.Count];
  89. for (int j = 0; j < values.Length; j++)
  90. {
  91. values[j] = allData[nums[j]][i];
  92. }
  93. QSort(values, nums, 0, nums.Count - 1);
  94. double chushu = 0;
  95. double lianxuMax = 0; //连续型属性的最大熵
  96. for (int j = 0; j < nums.Count - 1; j++)
  97. {
  98. currentValue = allData[nums[j]][lieshu - 1];
  99. currentPoint = allData[nums[j]][i];
  100. if (j == 0)
  101. {
  102. lastValue = currentValue;
  103. lastPoint = currentPoint;
  104. }
  105. if (currentValue != lastValue)
  106. {
  107. double shang1 = CalEntropy(count1, all1);
  108. double shang2 = CalEntropy(count2, all2);
  109. double allShang = shang1 * all1 / (all1 + all2) + shang2 * all2 / (all1 + all2);
  110. allShang = (-totalShang + allShang);
  111. if (lianxuMax < allShang)
  112. {
  113. lianxuMax = allShang;
  114. for (int k = 0; k < count1.Length; k++)
  115. {
  116. leftCount[k] = count1[k];
  117. rightCount[k] = count2[k];
  118. }
  119. splitPoint = j;
  120. splitValue = (currentPoint + lastPoint) / 2;
  121. }
  122. }
  123. all1++;
  124. count1[Convert.ToInt32(currentValue) - 1]++;
  125. count2[Convert.ToInt32(currentValue) - 1]--;
  126. all2--;
  127. lastValue = currentValue;
  128. lastPoint = currentPoint;
  129. }
  130. double rate1 = Convert.ToDouble(leftCount[0] + leftCount[1]) / (leftCount[0] + leftCount[1] + rightCount[0] + rightCount[1]);
  131. chushu = 0;
  132. if (rate1 > 0)
  133. {
  134. chushu = chushu + rate1 * Math.Log(rate1, 2);
  135. }
  136. double rate2 = Convert.ToDouble(rightCount[0] + rightCount[1]) / (leftCount[0] + leftCount[1] + rightCount[0] + rightCount[1]);
  137. if (rate2 > 0)
  138. {
  139. chushu = chushu + rate2 * Math.Log(rate2, 2);
  140. }
  141. //lianxuMax = lianxuMax ;
  142. //lianxuMax = lianxuMax;
  143. if (lianxuMax > jubuMax)
  144. {
  145. //info.setSplitIndex(i);
  146. info.splitIndex=(i);
  147. //info.setFeatures(new List<String> { splitValue + "" });
  148. info.features = (new List<String> { splitValue + "" });
  149. //info.setType(1);
  150. info.type=(1);
  151. jubuMax = lianxuMax;
  152. //info.setType(1);
  153. List<int>[] allInt = new List<int>[2];
  154. allInt[0] = new List<int>();
  155. allInt[1] = new List<int>();
  156. for (int k = 0; k < splitPoint; k++)
  157. {
  158. allInt[0].Add(nums[k]);
  159. }
  160. for (int k = splitPoint; k < nums.Count; k++)
  161. {
  162. allInt[1].Add(nums[k]);
  163. }
  164. info.temp=(allInt);
  165. //info.setTemp(allInt);
  166. double[][] alls = new double[2][];
  167. alls[0] = new double[leftCount.Length];
  168. alls[1] = new double[leftCount.Length];
  169. for (int k = 0; k < leftCount.Length; k++)
  170. {
  171. alls[0][k] = leftCount[k];
  172. alls[1][k] = rightCount[k];
  173. }
  174. info.class_Count=(alls);
  175. //info.setclassCount(alls);
  176. }
  177. }
  178. #endregion
  179. }
  180. #region 如果找不到最佳的分裂属性,则设为叶节点
  181. if (info.splitIndex == -1)
  182. {
  183. double[] finalCount = node.ClassCount;
  184. double max = finalCount[0];
  185. int result = 1;
  186. for (int i = 1; i < finalCount.Length; i++)
  187. {
  188. if (finalCount[i] > max)
  189. {
  190. max = finalCount[i];
  191. result = (i + 1);
  192. }
  193. }
  194. node.feature_Type=("result");
  195. node.features=(new List<String> { "" + result });
  196. return node;
  197. }
  198. #endregion
  199. #region 分裂
  200. int deep = node.deep;
  201. node.SplitFeature=("" + info.splitIndex);
  202. List<Node> childNode = new List<Node>();
  203. int[] used = new int[isUsed.Length];
  204. for (int i = 0; i < used.Length; i++)
  205. {
  206. used[i] = isUsed[i];
  207. }
  208. if (info.type == 0)
  209. {
  210. used[info.splitIndex] = 1;
  211. node.feature_Type=("离散");
  212. }
  213. else
  214. {
  215. used[info.splitIndex] = 0;
  216. node.feature_Type=("连续");
  217. }
  218. int sumLeaf = 0;
  219. int sumWrong = 0;
  220. List<int>[] rowIndex = info.temp;
  221. List<String> features = info.features;
  222. for (int j = 0; j < rowIndex.Length; j++)
  223. {
  224. if (rowIndex[j].Count == 0)
  225. {
  226. continue;
  227. }
  228. if (info.type == 0)
  229. features.Add("" + (j + 1));
  230. Node node1 = new Node();
  231. node1.setClassCount(info.class_Count[j]);
  232. node1.deep=(deep + 1);
  233. node1.rowCount = info.temp[j].Count;
  234. node1 = findBestSplit(node1, info.temp[j], used);
  235. sumLeaf += node1.leafNode_Count;
  236. sumWrong += node1.leafWrong;
  237. childNode.Add(node1);
  238. }
  239. node.leafNode_Count = (sumLeaf);
  240. node.leafWrong = (sumWrong);
  241. node.features=(features);
  242. node.childNodes=(childNode);
  243. #endregion
  244. return node;
  245. }
  246. catch (Exception e)
  247. {
  248. Console.WriteLine(e.StackTrace);
  249. return node;
  250. }
  251. }

(4)剪枝
悲观剪枝方法(PEP):

  1. public static void prune(Node node)
  2. {
  3. if (node.feature_Type == "result")
  4. return;
  5. double treeWrong = node.getErrorCount() + 0.5;
  6. double leafError = node.leafWrong + 0.5 * node.leafNode_Count;
  7. double var = Math.Sqrt(leafError * (1 - Convert.ToDouble(leafError) / node.nums.Count));
  8. double panbie = leafError + var - treeWrong;
  9. if (panbie > 0)
  10. {
  11. node.feature_Type=("result");
  12. node.childNodes=(null);
  13. int result = (node.result + 1);
  14. node.features=(new List<String>() { "" + result });
  15. }
  16. else
  17. {
  18. List<Node> childNodes = node.childNodes;
  19. for (int i = 0; i < childNodes.Count; i++)
  20. {
  21. prune(childNodes[i]);
  22. }
  23. }
  24. }

C4.5核心算法的所有源代码:

  1. #region C4.5核心算法
  2. /// <summary>
  3. /// 测试
  4. /// </summary>
  5. /// <param name="node"></param>
  6. /// <param name="data"></param>
  7. public static String findResult(Node node, String[] data)
  8. {
  9. List<String> featrues = node.features;
  10. String type = node.feature_Type;
  11. if (type == "result")
  12. {
  13. return featrues[0];
  14. }
  15. int split = Convert.ToInt32(node.SplitFeature);
  16. List<Node> childNodes = node.childNodes;
  17. double[] resultCount = node.ClassCount;
  18. if (type == "连续")
  19. {
  20. double value = Convert.ToDouble(featrues[0]);
  21. if (Convert.ToDouble(data[split]) <= value)
  22. {
  23. return findResult(childNodes[0], data);
  24. }
  25. else
  26. {
  27. return findResult(childNodes[1], data);
  28. }
  29. }
  30. else
  31. {
  32. for (int i = 0; i < featrues.Count; i++)
  33. {
  34. if (data[split] == featrues[i])
  35. {
  36. return findResult(childNodes[i], data);
  37. }
  38. if (i == featrues.Count - 1)
  39. {
  40. double count = resultCount[0];
  41. int maxInt = 0;
  42. for (int j = 1; j < resultCount.Length; j++)
  43. {
  44. if (count < resultCount[j])
  45. {
  46. count = resultCount[j];
  47. maxInt = j;
  48. }
  49. }
  50. return findResult(childNodes[0], data);
  51. }
  52. }
  53. }
  54. return null;
  55. }
  56. /// <summary>
  57. /// 判断是否还需要分裂
  58. /// </summary>
  59. /// <param name="node"></param>
  60. /// <returns></returns>
  61. public static bool ifEnd(Node node, double entropy,int[] isUsed)
  62. {
  63. try
  64. {
  65. double[] count = node.ClassCount;
  66. int rowCount = node.rowCount;
  67. int maxResult = 0;
  68. #region 数达到某一深度
  69. int deep = node.deep;
  70. if (deep >= maxDeep)
  71. {
  72. maxResult = node.result + 1;
  73. node.feature_Type=("result");
  74. node.features=(new List<String>() { maxResult + "" });
  75. node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
  76. node.leafNode_Count = 1;
  77. return true;
  78. }
  79. #endregion
  80. #region 纯度(其实跟后面的有点重了,记得要修改)
  81. //maxResult = 1;
  82. //for (int i = 1; i < count.Length; i++)
  83. //{
  84. // if (count[i] / rowCount >= 0.95)
  85. // {
  86. // node.feature_Type=("result");
  87. // node.features=(new List<String> { "" + (i + 1) });
  88. // node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
  89. // node.leafNode_Count = 1;
  90. // return true;
  91. // }
  92. //}
  93. #endregion
  94. #region 熵为0
  95. if (entropy == 0)
  96. {
  97. maxResult = node.result+1;
  98. node.feature_Type=("result");
  99. node.features=(new List<String> { maxResult + "" });
  100. node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
  101. node.leafNode_Count = 1;
  102. return true;
  103. }
  104. #endregion
  105. #region 属性已经分完
  106. bool flag = true;
  107. for (int i = 0; i < isUsed.Length - 1; i++)
  108. {
  109. if (isUsed[i] == 0)
  110. {
  111. flag = false;
  112. break;
  113. }
  114. }
  115. if (flag)
  116. {
  117. maxResult = node.result+1;
  118. node.feature_Type=("result");
  119. node.features=(new List<String> { "" + (maxResult) });
  120. node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
  121. node.leafNode_Count = 1;
  122. return true;
  123. }
  124. #endregion
  125. #region 数据量少于100
  126. if (rowCount < Limit_Node)
  127. {
  128. maxResult = node.result+1;
  129. node.feature_Type=("result");
  130. node.features=(new List<String> { "" + (maxResult) });
  131. node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
  132. node.leafNode_Count = 1;
  133. return true;
  134. }
  135. #endregion
  136. return false;
  137. }
  138. catch (Exception e)
  139. {
  140. return false;
  141. }
  142. }
  143. #region 排序算法
  144. public static void InsertSort(double[] values, List<int> arr, int StartIndex, int endIndex)
  145. {
  146. for (int i = StartIndex + 1; i <= endIndex; i++)
  147. {
  148. int key = arr[i];
  149. double init = values[i];
  150. int j = i - 1;
  151. while (j >= StartIndex && values[j] > init)
  152. {
  153. arr[j + 1] = arr[j];
  154. values[j + 1] = values[j];
  155. j--;
  156. }
  157. arr[j + 1] = key;
  158. values[j + 1] = init;
  159. }
  160. }
  161. static int SelectPivotMedianOfThree(double[] values, List<int> arr, int low, int high)
  162. {
  163. int mid = low + ((high - low) >> 1);//计算数组中间的元素的下标
  164. //使用三数取中法选择枢轴
  165. if (values[mid] > values[high])//目标: arr[mid] <= arr[high]
  166. {
  167. swap(values, arr, mid, high);
  168. }
  169. if (values[low] > values[high])//目标: arr[low] <= arr[high]
  170. {
  171. swap(values, arr, low, high);
  172. }
  173. if (values[mid] > values[low]) //目标: arr[low] >= arr[mid]
  174. {
  175. swap(values, arr, mid, low);
  176. }
  177. //此时,arr[mid] <= arr[low] <= arr[high]
  178. return low;
  179. //low的位置上保存这三个位置中间的值
  180. //分割时可以直接使用low位置的元素作为枢轴,而不用改变分割函数了
  181. }
  182. static void swap(double[] values, List<int> arr, int t1, int t2)
  183. {
  184. double temp = values[t1];
  185. values[t1] = values[t2];
  186. values[t2] = temp;
  187. int key = arr[t1];
  188. arr[t1] = arr[t2];
  189. arr[t2] = key;
  190. }
  191. static void QSort(double[] values, List<int> arr, int low, int high)
  192. {
  193. int first = low;
  194. int last = high;
  195. int left = low;
  196. int right = high;
  197. int leftLen = 0;
  198. int rightLen = 0;
  199. if (high - low + 1 < 10)
  200. {
  201. InsertSort(values, arr, low, high);
  202. return;
  203. }
  204. //一次分割
  205. int key = SelectPivotMedianOfThree(values, arr, low, high);//使用三数取中法选择枢轴
  206. double inti = values[key];
  207. int currentKey = arr[key];
  208. while (low < high)
  209. {
  210. while (high > low && values[high] >= inti)
  211. {
  212. if (values[high] == inti)//处理相等元素
  213. {
  214. swap(values, arr, right, high);
  215. right--;
  216. rightLen++;
  217. }
  218. high--;
  219. }
  220. arr[low] = arr[high];
  221. values[low] = values[high];
  222. while (high > low && values[low] <= inti)
  223. {
  224. if (values[low] == inti)
  225. {
  226. swap(values, arr, left, low);
  227. left++;
  228. leftLen++;
  229. }
  230. low++;
  231. }
  232. arr[high] = arr[low];
  233. values[high] = values[low];
  234. }
  235. arr[low] = currentKey;
  236. values[low] = values[key];
  237. //一次快排结束
  238. //把与枢轴key相同的元素移到枢轴最终位置周围
  239. int i = low - 1;
  240. int j = first;
  241. while (j < left && values[i] != inti)
  242. {
  243. swap(values, arr, i, j);
  244. i--;
  245. j++;
  246. }
  247. i = low + 1;
  248. j = last;
  249. while (j > right && values[i] != inti)
  250. {
  251. swap(values, arr, i, j);
  252. i++;
  253. j--;
  254. }
  255. QSort(values, arr, first, low - 1 - leftLen);
  256. QSort(values, arr, low + 1 + rightLen, last);
  257. }
  258. #endregion
  259. /// <summary>
  260. /// 寻找最佳的分裂点
  261. /// </summary>
  262. /// <param name="num"></param>
  263. /// <param name="node"></param>
  264. public static Node findBestSplit(Node node, List<int> nums, int[] isUsed)
  265. {
  266. try
  267. {
  268. //判断是否继续分裂
  269. double totalShang = CalEntropy(node.ClassCount, node.rowCount);
  270. if (ifEnd(node, totalShang,isUsed))
  271. {
  272. return node;
  273. }
  274. #region 变量声明
  275. SplitInfo info = new SplitInfo();
  276. int RowCount = nums.Count; //样本总数
  277. double jubuMax = 0; //局部最大熵
  278. #endregion
  279. for (int i = 0; i < isUsed.Length - 1; i++)
  280. {
  281. if (isUsed[i] == 1)
  282. {
  283. continue;
  284. }
  285. #region 离散变量
  286. if (type[i] == 0)
  287. {
  288. int[] allFeatureCount = new int[0]; //所有类别的数量
  289. double[][] allCount = new double[allNum[i]][];
  290. for (int j = 0; j < allCount.Length; j++)
  291. {
  292. allCount[j] = new double[classCount];
  293. }
  294. int[] countAllFeature = new int[allNum[i]];
  295. List<int>[] temp = new List<int>[allNum[i]];
  296. for (int j = 0; j < temp.Length; j++)
  297. {
  298. temp[j] = new List<int>();
  299. }
  300. for (int j = 0; j < nums.Count; j++)
  301. {
  302. int index = Convert.ToInt32(allData[nums[j]][i]);
  303. temp[index - 1].Add(nums[j]);
  304. countAllFeature[index - 1]++;
  305. allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
  306. }
  307. double allShang = 0;
  308. double chushu = 0;
  309. for (int j = 0; j < allCount.Length; j++)
  310. {
  311. allShang = allShang + CalEntropy(allCount[j], countAllFeature[j]) * countAllFeature[j] / RowCount;
  312. if (countAllFeature[j] > 0)
  313. {
  314. double rate = countAllFeature[j] / Convert.ToDouble(RowCount);
  315. chushu = chushu + rate * Math.Log(rate, 2);
  316. }
  317. }
  318. allShang = (-totalShang + allShang);
  319. if (allShang > jubuMax)
  320. {
  321. info.features = new List<string>();
  322. info.type = 0;
  323. info.temp = temp;
  324. info.splitIndex = i;
  325. info.class_Count = allCount;
  326. jubuMax = allShang;
  327. allFeatureCount = countAllFeature;
  328. }
  329. }
  330. #endregion
  331. #region 连续变量
  332. else
  333. {
  334. double[] leftCount = new double[classCount]; //做节点各个类别的数量
  335. double[] rightCount = new double[classCount]; //右节点各个类别的数量
  336. double[] count1 = new double[classCount]; //子集1的统计量
  337. //double[] count2 = new double[node.getCount().Length]; //子集2的统计量
  338. double[] count2 = new double[node.ClassCount.Length]; //子集2的统计量
  339. for (int j = 0; j < node.ClassCount.Length; j++)
  340. {
  341. count2[j] = node.ClassCount[j];
  342. }
  343. int all1 = 0; //子集1的样本量
  344. int all2 = nums.Count; //子集2的样本量
  345. double lastValue = 0; //上一个记录的类别
  346. double currentValue = 0; //当前类别
  347. double lastPoint = 0; //上一个点的值
  348. double currentPoint = 0; //当前点的值
  349. int splitPoint = 0;
  350. double splitValue = 0;
  351. double[] values = new double[nums.Count];
  352. for (int j = 0; j < values.Length; j++)
  353. {
  354. values[j] = allData[nums[j]][i];
  355. }
  356. QSort(values, nums, 0, nums.Count - 1);
  357. double chushu = 0;
  358. double lianxuMax = 0; //连续型属性的最大熵
  359. for (int j = 0; j < nums.Count - 1; j++)
  360. {
  361. currentValue = allData[nums[j]][lieshu - 1];
  362. currentPoint = allData[nums[j]][i];
  363. if (j == 0)
  364. {
  365. lastValue = currentValue;
  366. lastPoint = currentPoint;
  367. }
  368. if (currentValue != lastValue)
  369. {
  370. double shang1 = CalEntropy(count1, all1);
  371. double shang2 = CalEntropy(count2, all2);
  372. double allShang = shang1 * all1 / (all1 + all2) + shang2 * all2 / (all1 + all2);
  373. allShang = (-totalShang + allShang);
  374. if (lianxuMax < allShang)
  375. {
  376. lianxuMax = allShang;
  377. for (int k = 0; k < count1.Length; k++)
  378. {
  379. leftCount[k] = count1[k];
  380. rightCount[k] = count2[k];
  381. }
  382. splitPoint = j;
  383. splitValue = (currentPoint + lastPoint) / 2;
  384. }
  385. }
  386. all1++;
  387. count1[Convert.ToInt32(currentValue) - 1]++;
  388. count2[Convert.ToInt32(currentValue) - 1]--;
  389. all2--;
  390. lastValue = currentValue;
  391. lastPoint = currentPoint;
  392. }
  393. double rate1 = Convert.ToDouble(leftCount[0] + leftCount[1]) / (leftCount[0] + leftCount[1] + rightCount[0] + rightCount[1]);
  394. chushu = 0;
  395. if (rate1 > 0)
  396. {
  397. chushu = chushu + rate1 * Math.Log(rate1, 2);
  398. }
  399. double rate2 = Convert.ToDouble(rightCount[0] + rightCount[1]) / (leftCount[0] + leftCount[1] + rightCount[0] + rightCount[1]);
  400. if (rate2 > 0)
  401. {
  402. chushu = chushu + rate2 * Math.Log(rate2, 2);
  403. }
  404. //lianxuMax = lianxuMax ;
  405. //lianxuMax = lianxuMax;
  406. if (lianxuMax > jubuMax)
  407. {
  408. //info.setSplitIndex(i);
  409. info.splitIndex=(i);
  410. //info.setFeatures(new List<String> { splitValue + "" });
  411. info.features = (new List<String> { splitValue + "" });
  412. //info.setType(1);
  413. info.type=(1);
  414. jubuMax = lianxuMax;
  415. //info.setType(1);
  416. List<int>[] allInt = new List<int>[2];
  417. allInt[0] = new List<int>();
  418. allInt[1] = new List<int>();
  419. for (int k = 0; k < splitPoint; k++)
  420. {
  421. allInt[0].Add(nums[k]);
  422. }
  423. for (int k = splitPoint; k < nums.Count; k++)
  424. {
  425. allInt[1].Add(nums[k]);
  426. }
  427. info.temp=(allInt);
  428. //info.setTemp(allInt);
  429. double[][] alls = new double[2][];
  430. alls[0] = new double[leftCount.Length];
  431. alls[1] = new double[leftCount.Length];
  432. for (int k = 0; k < leftCount.Length; k++)
  433. {
  434. alls[0][k] = leftCount[k];
  435. alls[1][k] = rightCount[k];
  436. }
  437. info.class_Count=(alls);
  438. //info.setclassCount(alls);
  439. }
  440. }
  441. #endregion
  442. }
  443. #region 如果找不到最佳的分裂属性,则设为叶节点
  444. if (info.splitIndex == -1)
  445. {
  446. double[] finalCount = node.ClassCount;
  447. double max = finalCount[0];
  448. int result = 1;
  449. for (int i = 1; i < finalCount.Length; i++)
  450. {
  451. if (finalCount[i] > max)
  452. {
  453. max = finalCount[i];
  454. result = (i + 1);
  455. }
  456. }
  457. node.feature_Type=("result");
  458. node.features=(new List<String> { "" + result });
  459. return node;
  460. }
  461. #endregion
  462. #region 分裂
  463. int deep = node.deep;
  464. node.SplitFeature=("" + info.splitIndex);
  465. List<Node> childNode = new List<Node>();
  466. int[] used = new int[isUsed.Length];
  467. for (int i = 0; i < used.Length; i++)
  468. {
  469. used[i] = isUsed[i];
  470. }
  471. if (info.type == 0)
  472. {
  473. used[info.splitIndex] = 1;
  474. node.feature_Type=("离散");
  475. }
  476. else
  477. {
  478. used[info.splitIndex] = 0;
  479. node.feature_Type=("连续");
  480. }
  481. int sumLeaf = 0;
  482. int sumWrong = 0;
  483. List<int>[] rowIndex = info.temp;
  484. List<String> features = info.features;
  485. for (int j = 0; j < rowIndex.Length; j++)
  486. {
  487. if (rowIndex[j].Count == 0)
  488. {
  489. continue;
  490. }
  491. if (info.type == 0)
  492. features.Add("" + (j + 1));
  493. Node node1 = new Node();
  494. node1.setClassCount(info.class_Count[j]);
  495. node1.deep=(deep + 1);
  496. node1.rowCount = info.temp[j].Count;
  497. node1 = findBestSplit(node1, info.temp[j], used);
  498. sumLeaf += node1.leafNode_Count;
  499. sumWrong += node1.leafWrong;
  500. childNode.Add(node1);
  501. }
  502. node.leafNode_Count = (sumLeaf);
  503. node.leafWrong = (sumWrong);
  504. node.features=(features);
  505. node.childNodes=(childNode);
  506. #endregion
  507. return node;
  508. }
  509. catch (Exception e)
  510. {
  511. Console.WriteLine(e.StackTrace);
  512. return node;
  513. }
  514. }
  515. /// <summary>
  516. /// 计算熵
  517. /// </summary>
  518. /// <param name="counts"></param>
  519. /// <param name="countAll"></param>
  520. /// <returns></returns>
  521. public static double CalEntropy(double[] counts, int countAll)
  522. {
  523. try
  524. {
  525. double allShang = 0;
  526. for (int i = 0; i < counts.Length; i++)
  527. {
  528. if (counts[i] == 0)
  529. {
  530. continue;
  531. }
  532. double rate = counts[i] / countAll;
  533. allShang = allShang + rate * Math.Log(rate, 2);
  534. }
  535. return allShang;
  536. }
  537. catch (Exception e)
  538. {
  539. return 0;
  540. }
  541. }
  542. #region 悲观剪枝
  543. public static void prune(Node node)
  544. {
  545. if (node.feature_Type == "result")
  546. return;
  547. double treeWrong = node.getErrorCount() + 0.5;
  548. double leafError = node.leafWrong + 0.5 * node.leafNode_Count;
  549. double var = Math.Sqrt(leafError * (1 - Convert.ToDouble(leafError) / node.rowCount));
  550. double panbie = leafError + var - treeWrong;
  551. if (panbie > 0)
  552. {
  553. node.feature_Type = "result";
  554. node.childNodes = null;
  555. int result = node.result + 1;
  556. node.features= new List<String>() { "" + result };
  557. }
  558. else
  559. {
  560. List<Node> childNodes = node.childNodes;
  561. for (int i = 0; i < childNodes.Count; i++)
  562. {
  563. prune(childNodes[i]);
  564. }
  565. }
  566. }
  567. #endregion
  568. #endregion

总结

要记住,C4.5是分类树最终要的算法,算法的思想其实很简单,但是分类的准确性高。可以说C4.5是ID3的升级版和强化版,解决了ID3未能解决的问题。要重点记住以下几个方面:

1、C4.5是采用信息增益率选择分裂的属性,解决了ID3选择属性时的偏向性问题;
2、C4.5能够对连续数据进行处理,采用一刀切的方式将连续型的数据切成两份,在选择切割点的时候使用信息增益作为择优的条件;
3、C4.5采用悲观剪枝的策略,一定程度上降低了过拟合的影响。