import pandas as pd
from math import log
import treePlotter
import operator

# Compute the Shannon entropy of a data set (class label in the last column)
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
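
# Quick sanity check (hypothetical toy rows, not from the original data):
#   calcShannonEnt([[1, 'yes'], [1, 'yes'], [0, 'no']])
# Two classes with probabilities 2/3 and 1/3 give
#   -(2/3)*log2(2/3) - (1/3)*log2(1/3) ≈ 0.918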
# Split the data set on a discrete feature: keep every sample whose value in
# column `axis` equals `value`, and drop that column from the returned rows
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
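
# Example (hypothetical rows): splitting on column 0 with value 'sunny' keeps
# the matching rows and removes that column:
#   splitDataSet([['sunny', 'hot', 'no'], ['rainy', 'hot', 'yes']], 0, 'sunny')
#   -> [['hot', 'no']]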
# Split the data set on a continuous feature with a binary split around `value`.
# `direction` selects which half is returned:
#   0 -> samples with feature value >  value
#   1 -> samples with feature value <= value
def splitContinuousDataSet(dataSet, axis, value, direction):
    retDataSet = []
    for featVec in dataSet:
        if direction == 0:
            if featVec[axis] > value:
                retDataSet.append(featVec)
        else:
            if featVec[axis] <= value:
                retDataSet.append(featVec)
    return retDataSet
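
# Example (hypothetical rows): with ds = [[1.0, 'a'], [2.0, 'b'], [3.0, 'c']],
#   splitContinuousDataSet(ds, 0, 2.0, 0) -> [[3.0, 'c']]              (> 2.0)
#   splitContinuousDataSet(ds, 0, 2.0, 1) -> [[1.0, 'a'], [2.0, 'b']]  (<= 2.0)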
# Choose the best feature to split on (maximum information gain)
def chooseBestFeatureToSplit(dataSet, labels):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    bestSplitDict = {}
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        # Continuous feature: search for the best binary split point
        if isinstance(featList[0], (int, float)):
            # Binary split: sort the values, then use the midpoints between
            # consecutive values as candidate split points, e.g. sorted values
            # [1.0, 2.0, 4.0] yield candidates [1.5, 3.0]
            sortfeatList = sorted(featList)
            splitList = []
            for j in range(len(sortfeatList) - 1):
                splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0)
            bestSplitEntropy = float('inf')
            bestSplit = -1
            for j in range(len(splitList)):
                value = splitList[j]
                subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
                subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)
                prob0 = len(subDataSet0) / float(len(dataSet))
                prob1 = len(subDataSet1) / float(len(dataSet))
                newEntropy = prob0 * calcShannonEnt(subDataSet0) \
                           + prob1 * calcShannonEnt(subDataSet1)
                if newEntropy < bestSplitEntropy:
                    bestSplitEntropy = newEntropy
                    bestSplit = j
            if bestSplit == -1:
                continue  # no candidate split point (single-row data set)
            bestSplitDict[labels[i]] = splitList[bestSplit]
            # Information gain at the best split point of this continuous feature
            infoGain = baseEntropy - bestSplitEntropy
        # Discrete feature: weighted entropy over each distinct value
        else:
            uniqueVals = set(featList)
            newEntropy = 0.0
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy
        # Keep the feature with the largest information gain
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    # Return the column index of the best feature and, for continuous
    # features, a dict {feature label: best split point}
    return bestFeature, bestSplitDict
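
# Quick check (hypothetical toy data): for
#   ds = [[1.0, 'no'], [2.0, 'no'], [3.0, 'yes'], [4.0, 'yes']]
# chooseBestFeatureToSplit(ds, ['width', 'class']) returns (0, {'width': 2.5}):
# the midpoint 2.5 separates the two classes perfectly, so the weighted split
# entropy is 0 and the information gain equals the base entropy (1.0 bit)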
# Decide the class of a leaf node by majority vote
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
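
# Example: majorityCnt(['yes', 'no', 'yes']) -> 'yes' (2 votes against 1)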
# Recursively build the decision tree
def createTree(dataSet, labels, data_full, labels_full):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        # Case 1: all samples belong to the same class, so stop splitting
        return classList[0]
    if len(dataSet[0]) == 1:
        # Case 2: all features are used up but the class labels are still
        # mixed; return the most common class
        return majorityCnt(classList)
    bestFeat, bestSplitDict = chooseBestFeatureToSplit(dataSet, labels)
    if bestFeat == -1:
        # Case 3: no feature yields a positive information gain;
        # return the most common class
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    # 1. The best feature is discrete: one branch per value, and the feature
    #    is removed from the label list before recursing
    if isinstance(dataSet[0][bestFeat], str):
        myTree = {bestFeatLabel: {}}
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        # Enumerate this feature's values over the FULL data set, so values
        # absent from the current subset still get a (majority-vote) branch
        bestFeatIndexInFull = labels_full.index(bestFeatLabel)
        featValuesFull = [example[bestFeatIndexInFull]
                          for example in data_full]
        uniqueValsFull = set(featValuesFull)
        del labels[bestFeat]
        for value in uniqueValsFull:
            if value in uniqueVals:
                subLabels = labels[:]
                valueDataSet = splitDataSet(dataSet, bestFeat, value)
                myTree[bestFeatLabel][value] = createTree(valueDataSet,
                                                          subLabels, data_full, labels_full)
            else:
                myTree[bestFeatLabel][value] = majorityCnt(classList)
    # 2. The best feature is continuous: binary split, and the feature is NOT
    #    deleted, so it may be split again further down the tree
    else:
        bestSplitValue = bestSplitDict[bestFeatLabel]
        bestFeatLabel = labels[bestFeat] + '<=' + str(bestSplitValue)
        myTree = {bestFeatLabel: {}}
        subDataSet0 = splitContinuousDataSet(dataSet, bestFeat, bestSplitValue, 0)  # > value
        subDataSet1 = splitContinuousDataSet(dataSet, bestFeat, bestSplitValue, 1)  # <= value
        myTree[bestFeatLabel]['no'] = createTree(subDataSet0, labels,
                                                 data_full, labels_full)
        myTree[bestFeatLabel]['yes'] = createTree(subDataSet1, labels,
                                                  data_full, labels_full)
    return myTree
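
# The returned tree is a nested dict; a hypothetical result could look like
#   {'texture': {'clear': {'density<=0.38': {'yes': 'no', 'no': 'yes'}},
#                'fuzzy': 'no'}}
# where discrete nodes branch on feature values and continuous nodes branch
# on 'yes'/'no' answers to the '<=' test embedded in the key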
# Classify a single sample by walking down the decision tree
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    classLabel = None
    if '<=' in firstStr:
        # Continuous node: the key looks like 'feature<=splitValue', so parse
        # out the feature name and threshold instead of looking up the raw key
        featLabel, splitValue = firstStr.split('<=')
        featIndex = featLabels.index(featLabel)
        key = 'yes' if float(testVec[featIndex]) <= float(splitValue) else 'no'
        subTree = secondDict[key]
        classLabel = classify(subTree, featLabels, testVec) \
            if isinstance(subTree, dict) else subTree
    else:
        # Discrete node: follow the branch whose key equals the feature value
        featIndex = featLabels.index(firstStr)
        for key in secondDict:
            if testVec[featIndex] == key:
                subTree = secondDict[key]
                classLabel = classify(subTree, featLabels, testVec) \
                    if isinstance(subTree, dict) else subTree
    return classLabel
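
# Example (hypothetical tree): with
#   tree = {'density<=0.38': {'yes': 'A', 'no': 'B'}}
#   classify(tree, ['density'], [0.2]) -> 'A'   (0.2 <= 0.38)
#   classify(tree, ['density'], [0.5]) -> 'B'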
# Classify a batch of samples
def classifyAll(inputTree, featLabels, testDataSet):
    classLabelAll = []
    for testVec in testDataSet:
        classLabelAll.append(classify(inputTree, featLabels, testVec))
    return classLabelAll
if __name__ == '__main__':
    df = pd.read_csv('data.csv')
    data = df.values[:, 3:].tolist()
    # Convert continuous feature columns to float in place;
    # discrete (string) columns are left untouched
    for row in data:
        for k in range(len(row) - 1):
            try:
                row[k] = float(row[k])
            except (TypeError, ValueError):
                pass
    label = df.columns.values[3:].tolist()
    label_full = label[:]  # unmutated copy: createTree deletes used entries from `label`
    train_data = data[1:200]
    train_data_full = train_data
    test_data = data[200:245]
    decisionTree = createTree(train_data, label, train_data_full, label_full)
    print('decisionTree:\n', decisionTree)
    treePlotter.createPlot(decisionTree)
    # Use the unmutated label list when classifying
    print('classifyResult:\n', classifyAll(decisionTree, label_full, test_data))