初识ID3

回顾决策树的基本知识,其构建过程主要有下述三个重要的问题:
(1)数据是怎么分裂的
(2)如何选择分类的属性
(3)什么时候停止分裂

从上述三个问题出发,以实际的例子对ID3算法进行阐述。

例:通过当天的天气、温度、湿度和季节预测明天的天气

表1 原始数据**

当天天气 温度 湿度 季节 明天天气
25 50 春天
21 48 春天
18 70 春天
28 41 夏天
8 65 冬天
18 43 夏天
24 56 秋天
18 76 秋天
31 61 夏天
6 43 冬天
15 55 秋天
4 58 冬天

**

1.数据分割

对于离散型数据,直接按照离散数据的取值进行分裂,每一个取值对应一个子节点,以“当前天气”为例对数据进行分割,如图1所示。
决策树系列(三)——ID3 - 图1

对于连续型数据,ID3原本是没有处理能力的,只有通过离散化将连续性数据转化成离散型数据再进行处理。

连续数据离散化是另外一个课题,本文不深入阐述,这里直接采用等距离数据划分的李算话方法。该方法先对数据进行排序,然后将连续型数据划分为多个区间,并使每一个区间的数据量基本相同,以温度为例对数据进行分割,如图2所示。
决策树系列(三)——ID3 - 图2

2. 选择最优分裂属性

  1. <br />ID3采用信息增益作为选择最优的分裂属性的方法,选择熵作为衡量节点纯度的标准,信息增益的计算公式如下:<br /> ![](https://cdn.nlark.com/yuque/0/2020/png/709883/1586512818008-a75a8ce6-81d5-41e0-9e3f-07a169eb6f1d.png#align=left&display=inline&height=47&originHeight=47&originWidth=392&size=0&status=done&style=none&width=392)

其中, 决策树系列(三)——ID3 - 图3表示父节点的熵; 决策树系列(三)——ID3 - 图4表示节点i的熵,熵越大,节点的信息量越多,越不纯; 决策树系列(三)——ID3 - 图5表示子节点i的数据量与父节点数据量之比。 决策树系列(三)——ID3 - 图6越大,表示分裂后的熵越小,子节点变得越纯,分类的效果越好,因此选择 决策树系列(三)——ID3 - 图7最大的属性作为分裂属性。

对上述的例子的跟节点进行分裂,分别计算每一个属性的信息增益,选择信息增益最大的属性进行分裂。

天气属性:(数据分割如上图1所示)
  决策树系列(三)——ID3 - 图8

温度:(数据分割如上图2所示)
决策树系列(三)——ID3 - 图9

湿度:
决策树系列(三)——ID3 - 图10
决策树系列(三)——ID3 - 图11

季节:
决策树系列(三)——ID3 - 图12
决策树系列(三)——ID3 - 图13
由于决策树系列(三)——ID3 - 图14最大,所以选择属性“季节”作为根节点的分裂属性。

3.停止分裂的条件

  1. <br />停止分裂的条件已经在**决策树**中阐述,这里不再进行阐述。<br />

(1)最小节点数

当节点的数据量小于一个指定的数量时,不继续分裂。两个原因:一是数据量较少时,再做分裂容易强化噪声数据的作用;二是降低树生长的复杂性。提前结束分裂一定程度上有利于降低过拟合的影响。
  

(2)熵或者基尼值小于阀值。

  1. <br />由上述可知,熵和基尼值的大小表示数据的复杂程度,当熵或者基尼值过小时,表示数据的纯度比较大,如果熵或者基尼值小于一定程度时,节点停止分裂。<br />  

(3)决策树的深度达到指定的条件

 
节点的深度可以理解为节点与决策树跟节点的距离,如根节点的子节点的深度为1,因为这些节点与跟节点的距离为1,子节点的深度要比父节点的深度大1。决策树的深度是所有叶子节点的最大深度,当深度到达指定的上限大小时,停止分裂。
  

(4)所有特征已经使用完毕,不能继续进行分裂

  1. <br />被动式停止分裂的条件,当已经没有可分的属性时,直接将当前节点设置为叶子节点。

程序设计及源代码(C#版本)

**

(1)数据处理

用二维数组存储原始的数据,每一行表示一条记录,前n-1列表示数据的属性,第n列表示分类的标签。

  1. static double[][] allData;

为了方便后面的处理,对离散属性进行数字化处理,将离散值表示成数字,并用一个链表数组进行存储,数组的第一个元素表示属性1的离散值。

  1. static List<String>[] featureValues;

那么经过处理后的表1数据可以转化为如表2所示的数据:

表2 初始化后的数据

当天天气 温度 湿度 季节 明天天气
1 25 50 1 1
2 21 48 1 2
2 18 70 1 3
1 28 41 2 1
3 8 65 3 2
1 18 43 2 1
2 24 56 4 1
3 18 76 4 2
3 31 61 2 1
2 6 43 3 3
1 15 55 4 2
3 4 58 3 3

其中,对于当天天气属性,数字{1,2,3}分别表示{晴,阴,雨};对于季节属性{1,2,3,4}分别表示{春天、夏天、冬天、秋天};对于明天天气{1,2,3}分别表示{晴、阴、雨}。
**

(2)两个类:节点类和分裂信息

  

a)节点类Node

  1. <br />该类存储了节点的信息,包括节点的数据量、节点选择的分裂属性、节点输出类、子节点的个数、子节点的分类误差等。
  1. class Node
  2. {
  3. /// <summary>
  4. /// 各个子节点的取值
  5. /// </summary>
  6. public List<String> features { get; set; }
  7. /// <summary>
  8. /// 分裂属性的类型
  9. /// </summary>
  10. public String feature_Type { get; set; }
  11. /// <summary>
  12. /// 分裂的属性
  13. /// </summary>
  14. public String SplitFeature { get; set; }
  15. /// <summary>
  16. /// 节点对应各个分类的数目
  17. /// </summary>
  18. public double[] ClassCount { get; set; }
  19. /// <summary>
  20. /// 各个孩子节点
  21. /// </summary>
  22. public List<Node> childNodes { get; set; }
  23. /// <summary>
  24. /// 父亲节点(未用到)
  25. /// </summary>
  26. public Node Parent { get; set; }
  27. /// <summary>
  28. /// 占比最大的类别
  29. /// </summary>
  30. public String finalResult { get; set; }
  31. /// <summary>
  32. /// 数的深度
  33. /// </summary>
  34. public int deep { get; set; }
  35. /// <summary>
  36. /// 该节点占比最大的类标号
  37. /// </summary>
  38. public int result { get; set; }
  39. /// <summary>
  40. /// 节点的数量
  41. /// </summary>
  42. public int rowCount{ get; set; }
  43. public void setClassCount(double[] count)
  44. {
  45. this.ClassCount = count;
  46. double max = ClassCount[0];
  47. int result = 0;
  48. for (int i = 1; i < ClassCount.Length; i++)
  49. {
  50. if (max < ClassCount[i])
  51. {
  52. max = ClassCount[i];
  53. result = i;
  54. }
  55. }
  56. //wrong = Convert.ToInt32(nums.Count - ClassCount[result]);
  57. this.result = result;
  58. }
  59. }

b)分裂信息类SplitInfo

  1. <br />该类存储节点进行分裂的信息,包括各个子节点的行坐标、子节点各个类的数目、该节点分裂的属性、属性的类型等。
  1. class SplitInfo
  2. {
  3. /// <summary>
  4. /// 分裂的列下标
  5. /// </summary>
  6. public int splitIndex { get; set; }
  7. /// <summary>
  8. /// 数据类型
  9. /// </summary>
  10. public int type { get; set; }
  11. /// <summary>
  12. /// 分裂属性的取值
  13. /// </summary>
  14. public List<String> features { get; set; }
  15. /// <summary>
  16. /// 各个节点的行坐标链表
  17. /// </summary>
  18. public List<int>[] temp { get; set; }
  19. /// <summary>
  20. /// 每个节点各类的数目
  21. /// </summary>
  22. public double[][] class_Count { get; set; }
  23. }

**

(3)节点分裂方法findBestSplit(Node node,List nums,int[] isUsed),该方法对节点进行分裂,返回值Node

其中:

  • node表示即将进行分裂的节点;
  • nums表示节点数据对应的行坐标列表;
  • isUsed表示到该节点位置所有属性的使用情况(1:表示该属性不能再次使用,0:表示该属性可以使用);

findBestSplit主要有以下几个组成部分:

1)节点分裂停止的判定

判断节点是否需要继续分裂,分裂判断条件如上文所述。源代码如下

  1. public static Object[] ifEnd(Node node, double entropy,int[] isUsed)
  2. {
  3. try
  4. {
  5. double[] count = node.ClassCount;
  6. int rowCount = node.rowCount;
  7. int maxResult = 0;
  8. double maxRate = 0;
  9. #region 数达到某一深度
  10. int deep = node.deep;
  11. if (deep >= maxDeep)
  12. {
  13. maxResult = node.result + 1;
  14. node.feature_Type=("result");
  15. node.features=(new List<String>() { maxResult + "" });
  16. return new Object[] { true, node };
  17. }
  18. #endregion
  19. #region 纯度(其实跟后面的有点重了,记得要修改)
  20. //maxResult = 1;
  21. //for (int i = 1; i < count.Length; i++)
  22. //{
  23. // if (count[i] / rowCount >= 0.95)
  24. // {
  25. // node.setFeatureType("result");
  26. // node.setFeatures(new List<String> { "" + (i + 1) });
  27. // return new Object[] { true, node };
  28. // }
  29. //}
  30. //node.setLeafWrong(rowCount - Convert.ToInt32(count[maxResult - 1]));
  31. #endregion
  32. #region 熵为0
  33. if (entropy == 0)
  34. {
  35. maxRate = count[0] / rowCount;
  36. maxResult = 1;
  37. for (int i = 1; i < count.Length; i++)
  38. {
  39. if (count[i] / rowCount >= maxRate)
  40. {
  41. maxRate = count[i] / rowCount;
  42. maxResult = i + 1;
  43. }
  44. }
  45. node.feature_Type=("result");
  46. node.features=(new List<String> { maxResult + "" });
  47. return new Object[] { true, node };
  48. }
  49. #endregion
  50. #region 属性已经分完
  51. //int[] isUsed = node.;
  52. bool flag = true;
  53. for (int i = 0; i < isUsed.Length - 1; i++)
  54. {
  55. if (isUsed[i] == 0)
  56. {
  57. flag = false;
  58. break;
  59. }
  60. }
  61. if (flag)
  62. {
  63. maxRate = count[0] / rowCount;
  64. maxResult = 1;
  65. for (int i = 1; i < count.Length; i++)
  66. {
  67. if (count[i] / rowCount >= maxRate)
  68. {
  69. maxRate = count[i] / rowCount;
  70. maxResult = i + 1;
  71. }
  72. }
  73. node.feature_Type=("result");
  74. node.features=(new List<String> { "" + (maxResult) });
  75. return new Object[] { true, node };
  76. }
  77. #endregion
  78. #region 数据量少于100
  79. if (rowCount < Limit_Node)
  80. {
  81. maxRate = count[0] / rowCount;
  82. maxResult = 1;
  83. for (int i = 1; i < count.Length; i++)
  84. {
  85. if (count[i] / rowCount >= maxRate)
  86. {
  87. maxRate = count[i] / rowCount;
  88. maxResult = i + 1;
  89. }
  90. }
  91. node.feature_Type=("result");
  92. node.features=(new List<String> { "" + (maxResult) });
  93. return new Object[] { true, node };
  94. }
  95. #endregion
  96. return new Object[] { false, node };
  97. }
  98. catch (Exception e)
  99. {
  100. return new Object[] { false, node };
  101. }
  102. }

2)寻找最优的分裂属性

寻找最优的分裂属性需要计算每一个分裂属性分裂后的信息增益,计算公式上文已给出,其中熵的计算代码如下:

  1. public static double CalEntropy(double[] counts, int countAll)
  2. {
  3. try
  4. {
  5. double allShang = 0;
  6. for (int i = 0; i < counts.Length; i++)
  7. {
  8. if (counts[i] == 0)
  9. {
  10. continue;
  11. }
  12. double rate = counts[i] / countAll;
  13. allShang = allShang + rate * Math.Log(rate, 2);
  14. }
  15. return -allShang;
  16. }
  17. catch (Exception e)
  18. {
  19. return 0;
  20. }
  21. }

3)进行分裂,同时子节点也执行相同的分类步骤

其实就是递归的过程,对每一个子节点执行findBestSplit方法进行分裂。
全部源代码:

  1. #region ID3核心算法
  2. /// <summary>
  3. /// 测试
  4. /// </summary>
  5. /// <param name="node"></param>
  6. /// <param name="data"></param>
  7. public static String findResult(Node node, String[] data)
  8. {
  9. List<String> featrues = node.features;
  10. String type = node.feature_Type;
  11. if (type == "result")
  12. {
  13. return featrues[0];
  14. }
  15. int split = Convert.ToInt32(node.SplitFeature);
  16. List<Node> childNodes = node.childNodes;
  17. double[] resultCount = node.ClassCount;
  18. if (type == "连续")
  19. {
  20. for (int i = 0; i < featrues.Count; i++)
  21. {
  22. double value = Convert.ToDouble(featrues[i]);
  23. if (Convert.ToDouble(data[split]) <= value)
  24. {
  25. return findResult(childNodes[i], data);
  26. }
  27. }
  28. return findResult(childNodes[featrues.Count], data);
  29. }
  30. else
  31. {
  32. for (int i = 0; i < featrues.Count; i++)
  33. {
  34. if (data[split] == featrues[i])
  35. {
  36. return findResult(childNodes[i], data);
  37. }
  38. if (i == featrues.Count - 1)
  39. {
  40. double count = resultCount[0];
  41. int maxInt = 0;
  42. for (int j = 1; j < resultCount.Length; j++)
  43. {
  44. if (count < resultCount[j])
  45. {
  46. count = resultCount[j];
  47. maxInt = j;
  48. }
  49. }
  50. return findResult(childNodes[0], data);
  51. }
  52. }
  53. }
  54. return null;
  55. }
  56. /// <summary>
  57. /// 判断是否还需要分裂
  58. /// </summary>
  59. /// <param name="node"></param>
  60. /// <returns></returns>
  61. public static Object[] ifEnd(Node node, double entropy,int[] isUsed)
  62. {
  63. try
  64. {
  65. double[] count = node.ClassCount;
  66. int rowCount = node.rowCount;
  67. int maxResult = 0;
  68. double maxRate = 0;
  69. #region 数达到某一深度
  70. int deep = node.deep;
  71. if (deep >= maxDeep)
  72. {
  73. maxResult = node.result + 1;
  74. node.feature_Type=("result");
  75. node.features=(new List<String>() { maxResult + "" });
  76. return new Object[] { true, node };
  77. }
  78. #endregion
  79. #region 纯度(其实跟后面的有点重了,记得要修改)
  80. //maxResult = 1;
  81. //for (int i = 1; i < count.Length; i++)
  82. //{
  83. // if (count[i] / rowCount >= 0.95)
  84. // {
  85. // node.setFeatureType("result");
  86. // node.setFeatures(new List<String> { "" + (i + 1) });
  87. // return new Object[] { true, node };
  88. // }
  89. //}
  90. //node.setLeafWrong(rowCount - Convert.ToInt32(count[maxResult - 1]));
  91. #endregion
  92. #region 熵为0
  93. if (entropy == 0)
  94. {
  95. maxRate = count[0] / rowCount;
  96. maxResult = 1;
  97. for (int i = 1; i < count.Length; i++)
  98. {
  99. if (count[i] / rowCount >= maxRate)
  100. {
  101. maxRate = count[i] / rowCount;
  102. maxResult = i + 1;
  103. }
  104. }
  105. node.feature_Type=("result");
  106. node.features=(new List<String> { maxResult + "" });
  107. return new Object[] { true, node };
  108. }
  109. #endregion
  110. #region 属性已经分完
  111. //int[] isUsed = node.;
  112. bool flag = true;
  113. for (int i = 0; i < isUsed.Length - 1; i++)
  114. {
  115. if (isUsed[i] == 0)
  116. {
  117. flag = false;
  118. break;
  119. }
  120. }
  121. if (flag)
  122. {
  123. maxRate = count[0] / rowCount;
  124. maxResult = 1;
  125. for (int i = 1; i < count.Length; i++)
  126. {
  127. if (count[i] / rowCount >= maxRate)
  128. {
  129. maxRate = count[i] / rowCount;
  130. maxResult = i + 1;
  131. }
  132. }
  133. node.feature_Type=("result");
  134. node.features=(new List<String> { "" + (maxResult) });
  135. return new Object[] { true, node };
  136. }
  137. #endregion
  138. #region 数据量少于100
  139. if (rowCount < Limit_Node)
  140. {
  141. maxRate = count[0] / rowCount;
  142. maxResult = 1;
  143. for (int i = 1; i < count.Length; i++)
  144. {
  145. if (count[i] / rowCount >= maxRate)
  146. {
  147. maxRate = count[i] / rowCount;
  148. maxResult = i + 1;
  149. }
  150. }
  151. node.feature_Type=("result");
  152. node.features=(new List<String> { "" + (maxResult) });
  153. return new Object[] { true, node };
  154. }
  155. #endregion
  156. return new Object[] { false, node };
  157. }
  158. catch (Exception e)
  159. {
  160. return new Object[] { false, node };
  161. }
  162. }
  163. #region 排序算法
  164. public static void InsertSort(double[] values, List<int> arr, int StartIndex, int endIndex)
  165. {
  166. for (int i = StartIndex + 1; i <= endIndex; i++)
  167. {
  168. int key = arr[i];
  169. double init = values[i];
  170. int j = i - 1;
  171. while (j >= StartIndex && values[j] > init)
  172. {
  173. arr[j + 1] = arr[j];
  174. values[j + 1] = values[j];
  175. j--;
  176. }
  177. arr[j + 1] = key;
  178. values[j + 1] = init;
  179. }
  180. }
  181. static int SelectPivotMedianOfThree(double[] values, List<int> arr, int low, int high)
  182. {
  183. int mid = low + ((high - low) >> 1);//计算数组中间的元素的下标
  184. //使用三数取中法选择枢轴
  185. if (values[mid] > values[high])//目标: arr[mid] <= arr[high]
  186. {
  187. swap(values, arr, mid, high);
  188. }
  189. if (values[low] > values[high])//目标: arr[low] <= arr[high]
  190. {
  191. swap(values, arr, low, high);
  192. }
  193. if (values[mid] > values[low]) //目标: arr[low] >= arr[mid]
  194. {
  195. swap(values, arr, mid, low);
  196. }
  197. //此时,arr[mid] <= arr[low] <= arr[high]
  198. return low;
  199. //low的位置上保存这三个位置中间的值
  200. //分割时可以直接使用low位置的元素作为枢轴,而不用改变分割函数了
  201. }
  202. static void swap(double[] values, List<int> arr, int t1, int t2)
  203. {
  204. double temp = values[t1];
  205. values[t1] = values[t2];
  206. values[t2] = temp;
  207. int key = arr[t1];
  208. arr[t1] = arr[t2];
  209. arr[t2] = key;
  210. }
  211. static void QSort(double[] values, List<int> arr, int low, int high)
  212. {
  213. int first = low;
  214. int last = high;
  215. int left = low;
  216. int right = high;
  217. int leftLen = 0;
  218. int rightLen = 0;
  219. if (high - low + 1 < 10)
  220. {
  221. InsertSort(values, arr, low, high);
  222. return;
  223. }
  224. //一次分割
  225. int key = SelectPivotMedianOfThree(values, arr, low, high);//使用三数取中法选择枢轴
  226. double inti = values[key];
  227. int currentKey = arr[key];
  228. while (low < high)
  229. {
  230. while (high > low && values[high] >= inti)
  231. {
  232. if (values[high] == inti)//处理相等元素
  233. {
  234. swap(values, arr, right, high);
  235. right--;
  236. rightLen++;
  237. }
  238. high--;
  239. }
  240. arr[low] = arr[high];
  241. values[low] = values[high];
  242. while (high > low && values[low] <= inti)
  243. {
  244. if (values[low] == inti)
  245. {
  246. swap(values, arr, left, low);
  247. left++;
  248. leftLen++;
  249. }
  250. low++;
  251. }
  252. arr[high] = arr[low];
  253. values[high] = values[low];
  254. }
  255. arr[low] = currentKey;
  256. values[low] = values[key];
  257. //一次快排结束
  258. //把与枢轴key相同的元素移到枢轴最终位置周围
  259. int i = low - 1;
  260. int j = first;
  261. while (j < left && values[i] != inti)
  262. {
  263. swap(values, arr, i, j);
  264. i--;
  265. j++;
  266. }
  267. i = low + 1;
  268. j = last;
  269. while (j > right && values[i] != inti)
  270. {
  271. swap(values, arr, i, j);
  272. i++;
  273. j--;
  274. }
  275. QSort(values, arr, first, low - 1 - leftLen);
  276. QSort(values, arr, low + 1 + rightLen, last);
  277. }
  278. #endregion
  279. /// <summary>
  280. /// 寻找最佳的分裂点
  281. /// </summary>
  282. /// <param name="num"></param>
  283. /// <param name="node"></param>
  284. public static Node findBestSplit(Node node, int lastCol,List<int> nums,int[] isUsed)
  285. {
  286. try
  287. {
  288. //判断是否继续分裂
  289. double totalShang = CalEntropy(node.ClassCount, nums.Count);
  290. Object[] check = ifEnd(node, totalShang, isUsed);
  291. if ((bool)check[0])
  292. {
  293. node = (Node)check[1];
  294. return node;
  295. }
  296. #region 变量声明
  297. SplitInfo info = new SplitInfo();
  298. //int[] isUsed = node.getUsed(); //连续变量or离散变量
  299. //List<int> nums = node.getNum(); //样本的标号
  300. int RowCount = nums.Count; //样本总数
  301. double jubuMax = 0; //局部最大熵
  302. #endregion
  303. for (int i = 0; i < isUsed.Length - 1; i++)
  304. {
  305. if (isUsed[i] == 1)
  306. {
  307. continue;
  308. }
  309. #region 离散变量
  310. if (type[i] == 0)
  311. {
  312. int[] allFeatureCount = new int[0]; //所有类别的数量
  313. double[][] allCount = new double[allNum[i]][];
  314. for (int j = 0; j < allCount.Length; j++)
  315. {
  316. allCount[j] = new double[classCount];
  317. }
  318. int[] countAllFeature = new int[allNum[i]];
  319. List<int>[] temp = new List<int>[allNum[i]];
  320. for (int j = 0; j < temp.Length; j++)
  321. {
  322. temp[j] = new List<int>();
  323. }
  324. for (int j = 0; j < nums.Count; j++)
  325. {
  326. int index = Convert.ToInt32(allData[nums[j]][i]);
  327. temp[index - 1].Add(nums[j]);
  328. countAllFeature[index - 1]++;
  329. allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
  330. }
  331. double allShang = 0;
  332. for (int j = 0; j < allCount.Length; j++)
  333. {
  334. allShang = allShang + CalEntropy(allCount[j], countAllFeature[j]) * countAllFeature[j] / RowCount;
  335. }
  336. allShang = (totalShang - allShang);
  337. if (allShang > jubuMax)
  338. {
  339. info.features=new List<String>();
  340. info.type=0;
  341. info.temp=(temp);
  342. info.splitIndex=(i);
  343. info.class_Count=(allCount);
  344. jubuMax = allShang;
  345. allFeatureCount = countAllFeature;
  346. }
  347. }
  348. #endregion
  349. #region 连续变量
  350. else
  351. {
  352. double[] leftCount = new double[classCount]; //做节点各个类别的数量
  353. double[] rightCount = new double[classCount]; //右节点各个类别的数量
  354. double[] values = new double[nums.Count];
  355. List<String> List_Feature = new List<string>();
  356. for (int j = 0; j < values.Length; j++)
  357. {
  358. values[j] = allData[nums[j]][i];
  359. }
  360. QSort(values, nums, 0, nums.Count - 1);
  361. int eachNum = nums.Count / 5;
  362. double lianxuMax = 0; //连续型属性的最大熵
  363. int index = 1;
  364. double[][] counts = new double[5][];
  365. List<int>[] temp = new List<int>[5];
  366. for (int j = 0; j < 5; j++)
  367. {
  368. counts[j] = new double[classCount];
  369. temp[j] = new List<int>();
  370. }
  371. for (int j = 0; j < nums.Count - 1; j++)
  372. {
  373. if (j >= index * eachNum&&index<5)
  374. {
  375. List_Feature.Add(allData[nums[j]][i]+"");
  376. lianxuMax += eachNum*CalEntropy(counts[index - 1], eachNum)/RowCount;
  377. index++;
  378. }
  379. temp[index-1].Add(nums[j]);
  380. counts[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1])-1]++;
  381. }
  382. lianxuMax += ((eachNum + nums.Count % 5)*CalEntropy(counts[index - 1], eachNum + nums.Count % 5) / RowCount);
  383. lianxuMax = totalShang - lianxuMax;
  384. if (lianxuMax > jubuMax)
  385. {
  386. info.splitIndex=(i);
  387. info.features=(List_Feature);
  388. info.type=(1);
  389. jubuMax = lianxuMax;
  390. info.temp=(temp);
  391. info.class_Count=(counts);
  392. }
  393. }
  394. #endregion
  395. }
  396. #region 如何找不到最佳的分裂属性,则设为叶节点
  397. if (info.splitIndex == -1)
  398. {
  399. double[] finalCount = node.ClassCount;
  400. double max = finalCount[0];
  401. int result = 1;
  402. for (int i = 1; i < finalCount.Length; i++)
  403. {
  404. if (finalCount[i] > max)
  405. {
  406. max = finalCount[i];
  407. result = (i + 1);
  408. }
  409. }
  410. node.feature_Type=("result");
  411. node.features=(new List<String> { "" + result });
  412. return node;
  413. }
  414. #endregion
  415. int deep = node.deep;
  416. #region 分裂
  417. node.SplitFeature=("" + info.splitIndex);
  418. List<Node> childNode = new List<Node>();
  419. int[] used = new int[isUsed.Length];
  420. for (int i = 0; i < used.Length; i++)
  421. {
  422. used[i] = isUsed[i];
  423. }
  424. if (info.type == 0)
  425. {
  426. used[info.splitIndex] = 1;
  427. node.feature_Type=("离散");
  428. }
  429. else
  430. {
  431. used[info.splitIndex] = 0;
  432. node.feature_Type=("连续");
  433. }
  434. int sumLeaf = 0;
  435. int sumWrong = 0;
  436. List<int>[] rowIndex = info.temp;
  437. List<String> features = info.features;
  438. for (int j = 0; j < rowIndex.Length; j++)
  439. {
  440. if (rowIndex[j].Count == 0)
  441. {
  442. continue;
  443. }
  444. if (info.type == 0)
  445. features.Add(""+(j+1));
  446. Node node1 = new Node();
  447. //node1.setNum(info.getTemp()[j]);
  448. node1.setClassCount(info.class_Count[j]);
  449. //node1.setUsed(used);
  450. node1.deep=(deep + 1);
  451. node1.rowCount = info.temp[j].Count;
  452. node1 = findBestSplit(node1, info.splitIndex,info.temp[j], used);
  453. childNode.Add(node1);
  454. }
  455. node.features=(features);
  456. node.childNodes=(childNode);
  457. #endregion
  458. return node;
  459. }
  460. catch (Exception e)
  461. {
  462. Console.WriteLine(e.StackTrace);
  463. return node;
  464. }
  465. }
  466. /// <summary>
  467. /// 计算熵
  468. /// </summary>
  469. /// <param name="counts"></param>
  470. /// <param name="countAll"></param>
  471. /// <returns></returns>
  472. public static double CalEntropy(double[] counts, int countAll)
  473. {
  474. try
  475. {
  476. double allShang = 0;
  477. for (int i = 0; i < counts.Length; i++)
  478. {
  479. if (counts[i] == 0)
  480. {
  481. continue;
  482. }
  483. double rate = counts[i] / countAll;
  484. allShang = allShang + rate * Math.Log(rate, 2);
  485. }
  486. return -allShang;
  487. }
  488. catch (Exception e)
  489. {
  490. return 0;
  491. }
  492. }
  493. #endregion

(注:上述代码只是ID3的核心代码,数据预处理的代码并没有给出,只要将预处理后的数据输入到主方法findBestSplit中,就可以得到最终的结果)

总结

  1. <br />ID3是基本的决策树构建算法,作为决策树经典的构建算法,其具有结构简单、清晰易懂的特点。虽然ID3比较灵活方便,但是有以下几个缺点:<br /> (1)采用信息增益进行分裂,分裂的精确度可能没有采用信息增益率进行分裂高<br /> (2)不能处理连续型数据,只能通过离散化将连续性数据转化为离散型数据<br /> (3)不能处理缺省值<br /> (4)没有对决策树进行剪枝处理,很可能会出现过拟合的问题