# Scope: datasets whose features are all discrete-valued (ID3 decision tree)
    2. # -*- coding: utf-8 -*-
    3. __author__ = 'Wsine'
    4. from math import log
    5. import operator
    6. import treePlotter
    7. def calcShannonEnt(dataSet):
    8. """
    9. 输入:数据集
    10. 输出:数据集的香农熵
    11. 描述:计算给定数据集的香农熵
    12. """
    13. numEntries = len(dataSet)
    14. labelCounts = {}
    15. for featVec in dataSet:
    16. currentLabel = featVec[-1]
    17. if currentLabel not in labelCounts.keys():
    18. labelCounts[currentLabel] = 0
    19. labelCounts[currentLabel] += 1
    20. shannonEnt = 0.0
    21. for key in labelCounts:
    22. prob = float(labelCounts[key])/numEntries
    23. shannonEnt -= prob * log(prob, 2)
    24. return shannonEnt
    25. def splitDataSet(dataSet, axis, value):
    26. """
    27. 输入:数据集,选择维度,选择值
    28. 输出:划分数据集
    29. 描述:按照给定特征划分数据集;去除选择维度中等于选择值的项
    30. """
    31. retDataSet = []
    32. for featVec in dataSet:
    33. if featVec[axis] == value:# 判断此列axis的值是否为value
    34. reduceFeatVec = featVec[:axis]# 此行数据的前axis
    35. reduceFeatVec.extend(featVec[axis+1:])# axis列之后的数据
    36. retDataSet.append(reduceFeatVec)# 注意extendappend的区别。
    37. # 三句合写为一句
    38. # retDataSet.append(featVec[:axis] + featVec[axis + 1:])
    39. # a、b为例,extend,append是在a原地址操作,改变的是a。
    40. # extend:去掉列表b最外层的[],然后追加到a。append:将整个列表b作为一个值来添加。
    41. # +:新的变量c来实现相加,相加的过程和extend一样,但不是在被加的对象的地址上操作的。
    42. return retDataSet
    43. def chooseBestFeatureToSplit(dataSet):
    44. """
    45. 输入:数据集
    46. 输出:最好的划分维度
    47. 描述:选择最好的数据集划分维度
    48. """
    49. numFeatures = len(dataSet[0]) - 1
    50. baseEntropy = calcShannonEnt(dataSet)
    51. bestInfoGain = 0.0
    52. bestFeature = -1
    53. for i in range(numFeatures):
    54. featList = [example[i] for example in dataSet]
    55. uniqueVals = set(featList)
    56. newEntropy = 0.0
    57. for value in uniqueVals:
    58. subDataSet = splitDataSet(dataSet, i, value)
    59. prob = len(subDataSet)/float(len(dataSet))
    60. newEntropy += prob * calcShannonEnt(subDataSet)
    61. infoGain = baseEntropy - newEntropy
    62. if (infoGain > bestInfoGain):
    63. bestInfoGain = infoGain
    64. bestFeature = i
    65. return bestFeature
    66. def majorityCnt(classList):
    67. """
    68. 输入:分类类别列表
    69. 输出:子节点的分类
    70. 描述:数据集已经处理了所有属性,但是类标签依然不是唯一的,
    71. 采用多数判决的方法决定该子节点的分类
    72. """
    73. classCount = {}
    74. for vote in classList:
    75. if vote not in classCount.keys():
    76. classCount[vote] = 0
    77. classCount[vote] += 1
    78. sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reversed=True)
    79. return sortedClassCount[0][0]
    80. def createTree(dataSet, labels):
    81. """
    82. 输入:数据集,特征标签
    83. 输出:决策树
    84. 描述:递归构建决策树,利用上述的函数
    85. """
    86. classList = [example[-1] for example in dataSet]
    87. if classList.count(classList[0]) == len(classList):
    88. # 类别完全相同,停止划分
    89. return classList[0]
    90. if len(dataSet[0]) == 1:
    91. # 遍历完所有特征时返回出现次数最多的
    92. return majorityCnt(classList)
    93. bestFeat = chooseBestFeatureToSplit(dataSet)
    94. bestFeatLabel = labels[bestFeat]
    95. myTree = {bestFeatLabel:{}}
    96. del(labels[bestFeat])
    97. # 得到列表包括节点所有的属性值
    98. featValues = [example[bestFeat] for example in dataSet]
    99. uniqueVals = set(featValues)
    100. for value in uniqueVals:
    101. subLabels = labels[:]
    102. myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    103. return myTree
    104. def classify(inputTree, featLabels, testVec):
    105. """
    106. 输入:决策树,分类标签,测试数据
    107. 输出:决策结果
    108. 描述:跑决策树
    109. """
    110. firstStr = list(inputTree.keys())[0]
    111. secondDict = inputTree[firstStr]
    112. featIndex = featLabels.index(firstStr)
    113. for key in secondDict.keys():
    114. if testVec[featIndex] == key:
    115. if type(secondDict[key]).__name__ == 'dict':
    116. classLabel = classify(secondDict[key], featLabels, testVec)
    117. else:
    118. classLabel = secondDict[key]
    119. return classLabel
    120. def classifyAll(inputTree, featLabels, testDataSet):
    121. """
    122. 输入:决策树,分类标签,测试数据集
    123. 输出:决策结果
    124. 描述:跑决策树
    125. """
    126. classLabelAll = []
    127. for testVec in testDataSet:
    128. classLabelAll.append(classify(inputTree, featLabels, testVec))
    129. return classLabelAll
    130. def storeTree(inputTree, filename):
    131. """
    132. 输入:决策树,保存文件路径
    133. 输出:
    134. 描述:保存决策树到文件
    135. """
    136. import pickle
    137. fw = open(filename, 'wb')
    138. pickle.dump(inputTree, fw)
    139. fw.close()
    140. def grabTree(filename):
    141. """
    142. 输入:文件路径名
    143. 输出:决策树
    144. 描述:从文件读取决策树
    145. """
    146. import pickle
    147. fr = open(filename, 'rb')
    148. return pickle.load(fr)
    149. def createDataSet():
    150. """
    151. outlook-> 0: sunny | 1: overcast | 2: rain
    152. temperature-> 0: hot | 1: mild | 2: cool
    153. humidity-> 0: high | 1: normal
    154. windy-> 0: false | 1: true
    155. """
    156. dataSet = [[0, 0, 0, 0, 'N'],
    157. [0, 0, 0, 1, 'N'],
    158. [1, 0, 0, 0, 'Y'],
    159. [2, 1, 0, 0, 'Y'],
    160. [2, 2, 1, 0, 'Y'],
    161. [2, 2, 1, 1, 'N'],
    162. [1, 2, 1, 1, 'Y']]
    163. labels = ['outlook', 'temperature', 'humidity', 'windy']
    164. return dataSet, labels
    165. def createTestSet():
    166. """
    167. outlook-> 0: sunny | 1: overcast | 2: rain
    168. temperature-> 0: hot | 1: mild | 2: cool
    169. humidity-> 0: high | 1: normal
    170. windy-> 0: false | 1: true
    171. """
    172. testSet = [[0, 1, 0, 0],
    173. [0, 2, 1, 0],
    174. [2, 1, 1, 0],
    175. [0, 1, 1, 1],
    176. [1, 1, 0, 1],
    177. [1, 0, 1, 0],
    178. [2, 1, 0, 1]]
    179. return testSet
    180. def main():
    181. dataSet, labels = createDataSet()
    182. labels_tmp = labels[:] # 拷贝,createTree会改变labels
    183. desicionTree = createTree(dataSet, labels_tmp)
    184. #storeTree(desicionTree, 'classifierStorage.txt')
    185. #desicionTree = grabTree('classifierStorage.txt')
    186. print('desicionTree:\n', desicionTree)
    187. treePlotter.createPlot(desicionTree)
    188. testSet = createTestSet()
    189. print('classifyResult:\n', classifyAll(desicionTree, labels, testSet))
    190. if __name__ == '__main__':
    191. main()