trees.py

The ID3 algorithm:

ID3 algorithm - Wikipedia, the free encyclopedia.pdf

Shannon entropy:

H = -Σ p(x_i) * log2(p(x_i)), summed over all classes x_i, where p(x_i) is the fraction of samples belonging to class x_i.

Computing the Shannon entropy of a given dataset:

from math import log

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:  # count the occurrences of each class label
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob, 2)  # log base 2
    return shannonEnt

That is: given a dataset, compute the probability of each class and plug the probabilities into the formula to get the Shannon entropy.

For example, computing the Shannon entropy of the sample dataset (figure: interpreter output).
The entropy increases as the number of classes grows (figure: interpreter output).
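To reproduce this, a small helper that builds the toy dataset is needed. The sketch below assumes the five-row fish dataset used in Machine Learning in Action; the exact rows and label names are an assumption, not taken from this note:

# Toy dataset: each row is a feature vector whose last element is the class label.
# The rows and feature names follow the book's fish example (assumed here).
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

myDat, labels = createDataSet()
print(calcShannonEnt(myDat))  # two classes with probs 2/5 and 3/5 -> about 0.971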

Splitting the dataset on a given feature:

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # chop out the axis used for splitting
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

Parameters (dataSet, axis, value): the dataset to split, the index of the feature to split on, and the feature value a row must match to be returned.

?? Why not just write reducedFeatVec = featVec[:]? What's the difference?
>>>>> Because slices are half-open intervals: featVec[:axis] stops just before index axis and featVec[axis+1:] resumes just after it, so concatenating the two drops the splitting feature itself, while featVec[:] would copy the whole row, feature included.
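A quick check of splitDataSet on the toy dataset above (assuming the createDataSet sketch from the entropy example):

myDat, labels = createDataSet()
print(splitDataSet(myDat, 0, 1))  # [[1, 'yes'], [1, 'yes'], [0, 'no']]
print(splitDataSet(myDat, 0, 0))  # [[1, 'no'], [1, 'no']]

Note that the matched feature itself is removed from every returned row.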

Choosing the best way to split the dataset:

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # the last column is used for the labels
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):  # iterate over all the features
        featList = [example[i] for example in dataSet]  # all values of this feature
        uniqueVals = set(featList)  # the set of unique values
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # the info gain, i.e. reduction in entropy
        if (infoGain > bestInfoGain):  # compare this to the best gain so far
            bestInfoGain = infoGain    # if better than current best, set to best
            bestFeature = i
    return bestFeature  # returns an integer index

(See: usage of the set() function in Python.)

The code above splits the dataset on each feature in turn, computes the weighted Shannon entropy of the resulting subsets, and takes the difference from the original dataset's entropy as that feature's information gain; the feature with the largest information gain is selected.
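On the assumed toy dataset, feature 0 wins; a quick sanity check, with the gains worked out from the entropy formula above:

myDat, labels = createDataSet()
# gain of feature 0: 0.971 - (3/5)*0.918 - (2/5)*0.0  ~= 0.420
# gain of feature 1: 0.971 - (4/5)*1.000 - (1/5)*0.0  ~= 0.171
print(chooseBestFeatureToSplit(myDat))  # 0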

import operator

# return the class label that occurs most often in classList:
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1:  # stop splitting when there are no more features in dataSet
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy labels, so the recursion doesn't mess up the existing list
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

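A sketch of building the tree on the assumed toy dataset. Note that createTree deletes entries from labels, so pass a copy if the original list is needed afterwards:

myDat, labels = createDataSet()
myTree = createTree(myDat, labels[:])  # labels[:] keeps the original list intact
print(myTree)
# expected, given the assumed dataset:
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}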

Plotting the tree with Matplotlib:

treePlotter.py

import matplotlib.pyplot as plt

# box and arrow styles for the annotations:
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

# draw an annotated node with an arrow from its parent:
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))   # tree width, in leaves
    plotTree.totalD = float(getTreeDepth(inTree))  # tree depth
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dict is an internal node; anything else is a leaf
            numLeafs += getNumLeafs(secondDict[key])
        else: numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dict is an internal node; anything else is a leaf
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else: thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth
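createPlot above calls plotTree, which this excerpt omits, so the module won't run as-is. Below is a reconstruction of the missing plotTree and plotMidText from the book's treePlotter.py; treat the layout arithmetic as a sketch rather than the canonical version:

# label the branch between a child and its parent with the feature value:
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # width of this subtree, in leaves
    firstStr = list(myTree)[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # descend one level
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))    # recurse into the subtree
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # come back up a level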

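Rendering the tree then takes a single call (assuming myTree from the createTree example above):

createPlot(myTree)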

Testing the algorithm: using the decision tree to classify

def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree)[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # map the feature name back to its column index
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):  # internal node: keep descending
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat     # leaf node: this is the predicted class
    return classLabel
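A classification sketch on the assumed toy tree. featLabels must be the original, unmutated label list, which is why createTree was given a copy:

myDat, labels = createDataSet()
myTree = createTree(myDat, labels[:])
print(classify(myTree, labels, [1, 0]))  # 'no'
print(classify(myTree, labels, [1, 1]))  # 'yes'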

Using the algorithm: storing the decision tree

def storeTree(inputTree, filename):
    import pickle
    with open(filename, 'wb') as fw:  # pickle requires binary mode
        pickle.dump(inputTree, fw)

def grabTree(filename):
    import pickle
    with open(filename, 'rb') as fr:  # with-block closes the file after loading
        return pickle.load(fr)

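Round-tripping the tree through pickle (assuming myTree from the earlier examples; the filename is arbitrary):

storeTree(myTree, 'classifierStorage.txt')
print(grabTree('classifierStorage.txt'))  # same dict as myTree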

Testing on the contact-lenses data (figures: the lenses tree printed and plotted).
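A sketch of the lenses experiment, assuming the tab-separated lenses.txt that ships with the book's chapter 3 code and its four feature names (both assumptions here):

with open('lenses.txt') as fr:
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
print(lensesTree)
createPlot(lensesTree)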