#适用范围:全离散型变量# -*- coding: utf-8 -*-__author__ = 'Wsine'from math import logimport operatorimport treePlotterdef calcShannonEnt(dataSet): """ 输入:数据集 输出:数据集的香农熵 描述:计算给定数据集的香农熵 """ numEntries = len(dataSet) labelCounts = {} for featVec in dataSet: currentLabel = featVec[-1] if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 shannonEnt = 0.0 for key in labelCounts: prob = float(labelCounts[key])/numEntries shannonEnt -= prob * log(prob, 2) return shannonEntdef splitDataSet(dataSet, axis, value): """ 输入:数据集,选择维度,选择值 输出:划分数据集 描述:按照给定特征划分数据集;去除选择维度中等于选择值的项 """ retDataSet = [] for featVec in dataSet: if featVec[axis] == value:# 判断此列axis的值是否为value reduceFeatVec = featVec[:axis]# 此行数据的前axis列 reduceFeatVec.extend(featVec[axis+1:])# axis列之后的数据 retDataSet.append(reduceFeatVec)# 注意extend与append的区别。 # 三句合写为一句 # retDataSet.append(featVec[:axis] + featVec[axis + 1:]) # a、b为例,extend,append是在a原地址操作,改变的是a。 # extend:去掉列表b最外层的[],然后追加到a。append:将整个列表b作为一个值来添加。 # +:新的变量c来实现相加,相加的过程和extend一样,但不是在被加的对象的地址上操作的。 return retDataSetdef chooseBestFeatureToSplit(dataSet): """ 输入:数据集 输出:最好的划分维度 描述:选择最好的数据集划分维度 """ numFeatures = len(dataSet[0]) - 1 baseEntropy = calcShannonEnt(dataSet) bestInfoGain = 0.0 bestFeature = -1 for i in range(numFeatures): featList = [example[i] for example in dataSet] uniqueVals = set(featList) newEntropy = 0.0 for value in uniqueVals: subDataSet = splitDataSet(dataSet, i, value) prob = len(subDataSet)/float(len(dataSet)) newEntropy += prob * calcShannonEnt(subDataSet) infoGain = baseEntropy - newEntropy if (infoGain > bestInfoGain): bestInfoGain = infoGain bestFeature = i return bestFeaturedef majorityCnt(classList): """ 输入:分类类别列表 输出:子节点的分类 描述:数据集已经处理了所有属性,但是类标签依然不是唯一的, 采用多数判决的方法决定该子节点的分类 """ classCount = {} for vote in classList: if vote not in classCount.keys(): classCount[vote] = 0 classCount[vote] += 1 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reversed=True) return sortedClassCount[0][0]def createTree(dataSet, labels): """ 输入:数据集,特征标签 输出:决策树 描述:递归构建决策树,利用上述的函数 """ classList = [example[-1] for example in dataSet] if classList.count(classList[0]) == len(classList): # 类别完全相同,停止划分 return classList[0] if len(dataSet[0]) == 1: # 遍历完所有特征时返回出现次数最多的 return majorityCnt(classList) bestFeat = chooseBestFeatureToSplit(dataSet) bestFeatLabel = labels[bestFeat] myTree = {bestFeatLabel:{}} del(labels[bestFeat]) # 得到列表包括节点所有的属性值 featValues = [example[bestFeat] for example in dataSet] uniqueVals = set(featValues) for value in uniqueVals: subLabels = labels[:] myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels) return myTreedef classify(inputTree, featLabels, testVec): """ 输入:决策树,分类标签,测试数据 输出:决策结果 描述:跑决策树 """ firstStr = list(inputTree.keys())[0] secondDict = inputTree[firstStr] featIndex = featLabels.index(firstStr) for key in secondDict.keys(): if testVec[featIndex] == key: if type(secondDict[key]).__name__ == 'dict': classLabel = classify(secondDict[key], featLabels, testVec) else: classLabel = secondDict[key] return classLabeldef classifyAll(inputTree, featLabels, testDataSet): """ 输入:决策树,分类标签,测试数据集 输出:决策结果 描述:跑决策树 """ classLabelAll = [] for testVec in testDataSet: classLabelAll.append(classify(inputTree, featLabels, testVec)) return classLabelAlldef storeTree(inputTree, filename): """ 输入:决策树,保存文件路径 输出: 描述:保存决策树到文件 """ import pickle fw = open(filename, 'wb') pickle.dump(inputTree, fw) fw.close()def grabTree(filename): """ 输入:文件路径名 输出:决策树 描述:从文件读取决策树 """ import pickle fr = open(filename, 'rb') return pickle.load(fr)def createDataSet(): """ outlook-> 0: sunny | 1: overcast | 2: rain temperature-> 0: hot | 1: mild | 2: cool humidity-> 0: high | 1: normal windy-> 0: false | 1: true """ dataSet = [[0, 0, 0, 0, 'N'], [0, 0, 0, 1, 'N'], [1, 0, 0, 0, 'Y'], [2, 1, 0, 0, 'Y'], [2, 2, 1, 0, 'Y'], [2, 2, 1, 1, 'N'], [1, 2, 1, 1, 'Y']] labels = ['outlook', 'temperature', 'humidity', 'windy'] return dataSet, labelsdef createTestSet(): """ outlook-> 0: sunny | 1: overcast | 2: rain temperature-> 0: hot | 1: mild | 2: cool humidity-> 0: high | 1: normal windy-> 0: false | 1: true """ testSet = [[0, 1, 0, 0], [0, 2, 1, 0], [2, 1, 1, 0], [0, 1, 1, 1], [1, 1, 0, 1], [1, 0, 1, 0], [2, 1, 0, 1]] return testSetdef main(): dataSet, labels = createDataSet() labels_tmp = labels[:] # 拷贝,createTree会改变labels desicionTree = createTree(dataSet, labels_tmp) #storeTree(desicionTree, 'classifierStorage.txt') #desicionTree = grabTree('classifierStorage.txt') print('desicionTree:\n', desicionTree) treePlotter.createPlot(desicionTree) testSet = createTestSet() print('classifyResult:\n', classifyAll(desicionTree, labels, testSet))if __name__ == '__main__': main()