一.决策树的构造

image.png

image.png

  1. from __future__ import print_function
  2. print(__doc__)
  3. import operator
  4. from math import log
  5. import decisionTreePlot as dtPlot
  6. from collections import Counter
  7. def createDataSet():
  8. labels = ['no surfacing', 'flippers']
  9. return dataSet, labels
  10. def calcShannonEnt(dataSet):
  11. numEntries = len(dataSet)
  12. labelCounts = {}
  13. for featVec in dataSet:
  14. currentLabel = featVec[-1]
  15. if currentLabel not in labelCounts.keys():
  16. labelCounts[currentLabel] = 0
  17. labelCounts[currentLabel] += 1
  18. # 对于label标签的占比,求出label标签的香农熵
  19. shannonEnt = 0.0
  20. for key in labelCounts:
  21. prob = float(labelCounts[key])/numEntries
  22. shannonEnt -= prob * log(prob, 2)
  23. return shannonEnt
  24. def splitDataSet(dataSet, index, value):
  25. retDataSet = []
  26. for featVec in dataSet:
  27. if featVec[index] == value:
  28. reducedFeatVec = featVec[:index]
  29. reducedFeatVec.extend(featVec[index+1:])
  30. retDataSet.append(reducedFeatVec)
  31. return retDataSet
  32. def chooseBestFeatureToSplit(dataSet):
  33. numFeatures = len(dataSet[0]) - 1
  34. baseEntropy = calcShannonEnt(dataSet)
  35. bestInfoGain, bestFeature = 0.0, -1
  36. for i in range(numFeatures):
  37. featList = [example[i] for example in dataSet]
  38. uniqueVals = set(featList)
  39. newEntropy = 0.0
  40. for value in uniqueVals:
  41. subDataSet = splitDataSet(dataSet, i, value)
  42. prob = len(subDataSet)/float(len(dataSet))
  43. newEntropy += prob * calcShannonEnt(subDataSet)
  44. infoGain = baseEntropy - newEntropy
  45. print('infoGain=', infoGain, 'bestFeature=', i, baseEntropy, newEntropy)
  46. if (infoGain > bestInfoGain):
  47. bestInfoGain = infoGain
  48. bestFeature = i
  49. return bestFeature
  50. def majorityCnt(classList):
  51. classCount = {}
  52. for vote in classList:
  53. if vote not in classCount.keys():
  54. classCount[vote] = 0
  55. classCount[vote] += 1
  56. sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
  57. return sortedClassCount[0][0]
  58. def createTree(dataSet, labels):
  59. classList = [example[-1] for example in dataSet]
  60. if classList.count(classList[0]) == len(classList):
  61. return classList[0]
  62. if len(dataSet[0]) == 1:
  63. return majorityCnt(classList)
  64. bestFeat = chooseBestFeatureToSplit(dataSet)
  65. bestFeatLabel = labels[bestFeat]
  66. myTree = {bestFeatLabel: {}}
  67. del(labels[bestFeat])
  68. featValues = [example[bestFeat] for example in dataSet]
  69. uniqueVals = set(featValues)
  70. for value in uniqueVals:
  71. subLabels = labels[:]
  72. myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
  73. return myTree
  74. def classify(inputTree, featLabels, testVec):
  75. firstStr = inputTree.keys()[0]
  76. secondDict = inputTree[firstStr]
  77. featIndex = featLabels.index(firstStr)
  78. key = testVec[featIndex]
  79. valueOfFeat = secondDict[key]
  80. print('+++', firstStr, 'xxx', secondDict, '---', key, '>>>', valueOfFeat)
  81. if isinstance(valueOfFeat, dict):
  82. classLabel = classify(valueOfFeat, featLabels, testVec)
  83. else:
  84. classLabel = valueOfFeat
  85. return classLabel
  86. def storeTree(inputTree, filename):
  87. import pickle
  88. fw = open(filename, 'wb')
  89. pickle.dump(inputTree, fw)
  90. fw.close()
  91. with open(filename, 'wb') as fw:
  92. pickle.dump(inputTree, fw)
  93. def grabTree(filename):
  94. import pickle
  95. fr = open(filename,'rb')
  96. return pickle.load(fr)
  97. def fishTest():
  98. import copy
  99. myTree = createTree(myDat, copy.deepcopy(labels))
  100. print(myTree)
  101. print(classify(myTree, labels, [1, 1]))
  102. # 获得树的高度
  103. print(get_tree_height(myTree))
  104. # 画图可视化展现
  105. dtPlot.createPlot(myTree)
  106. def ContactLensesTest():
  107. fr = open('data/3.DecisionTree/lenses.txt')
  108. lenses = [inst.strip().split('\t') for inst in fr.readlines()]
  109. lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
  110. lensesTree = createTree(lenses, lensesLabels)
  111. print(lensesTree)
  112. # 画图可视化展现
  113. dtPlot.createPlot(lensesTree)
  114. def get_tree_height(tree):
  115. if not isinstance(tree, dict):
  116. return 1
  117. child_trees = tree.values()[0].values()
  118. # 遍历子树, 获得子树的最大高度
  119. max_height = 0
  120. for child_tree in child_trees:
  121. child_tree_height = get_tree_height(child_tree)
  122. if child_tree_height > max_height:
  123. max_height = child_tree_height
  124. return max_height + 1
  125. if __name__ == "__main__":
  126. fishTest()
  127. # ContactLensesTest()