regTrees.py / treeExplore.py

CART算法和回归树

回归树的构建

  1. from numpy import *
  2. def loadDataSet(fileName): #general function to parse tab -delimited floats
  3. dataMat = [] #assume last column is target value
  4. fr = open(fileName)
  5. for line in fr.readlines():
  6. curLine = line.strip().split('\t')
  7. fltLine = list(map(float,curLine)) #map all elements to float()
  8. dataMat.append(fltLine)
  9. return dataMat
  10. def binSplitDataSet(dataSet, feature, value):
  11. mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]
  12. mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
  13. return mat0,mat1
  14. def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is NumPy Mat so we can array filtering
  15. feat, val = chooseBestSplit(dataSet, leafType, errType, ops)#choose the best split
  16. if feat == None: return val #if the splitting hit a stop condition return val
  17. retTree = {}
  18. retTree['spInd'] = feat
  19. retTree['spVal'] = val
  20. lSet, rSet = binSplitDataSet(dataSet, feat, val)
  21. retTree['left'] = createTree(lSet, leafType, errType, ops)
  22. retTree['right'] = createTree(rSet, leafType, errType, ops)
  23. return retTree

image.png

  1. def regLeaf(dataSet):#returns the value used for each leaf
  2. return mean(dataSet[:,-1])
  3. def regErr(dataSet):
  4. return var(dataSet[:,-1]) * shape(dataSet)[0]
  5. def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
  6. tolS = ops[0]; tolN = ops[1]
  7. #if all the target variables are the same value: quit and return value
  8. if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit cond 1
  9. return None, leafType(dataSet)
  10. m,n = shape(dataSet)
  11. #the choice of the best feature is driven by Reduction in RSS error from mean
  12. S = errType(dataSet)
  13. bestS = inf; bestIndex = 0; bestValue = 0
  14. for featIndex in range(n-1):
  15. for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
  16. mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
  17. if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
  18. newS = errType(mat0) + errType(mat1)
  19. if newS < bestS:
  20. bestIndex = featIndex
  21. bestValue = splitVal
  22. bestS = newS
  23. #if the decrease (S-bestS) is less than a threshold don't do the split
  24. if (S - bestS) < tolS:
  25. return None, leafType(dataSet) #exit cond 2
  26. mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
  27. if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): #exit cond 3
  28. return None, leafType(dataSet)
  29. return bestIndex,bestValue#returns the best feature to split on
  30. #and the value used for that split

image.png

已完成对回归树的构建,下面将进行树剪枝以达到更好的预测结果(尽量避免“过拟合”)

树剪枝

image.png

以上的例子说明 tolS 对误差的量级很敏感:取值不当会产生过多的叶节点。

  1. def isTree(obj):
  2. return (type(obj).__name__=='dict')
  3. def getMean(tree):
  4. if isTree(tree['right']): tree['right'] = getMean(tree['right'])
  5. if isTree(tree['left']): tree['left'] = getMean(tree['left'])
  6. return (tree['left']+tree['right'])/2.0
  7. def prune(tree, testData):
  8. if shape(testData)[0] == 0: return getMean(tree) #if we have no test data collapse the tree
  9. if (isTree(tree['right']) or isTree(tree['left'])):#if the branches are not trees try to prune them
  10. lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
  11. if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
  12. if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
  13. #if they are now both leafs, see if we can merge them
  14. if not isTree(tree['left']) and not isTree(tree['right']):
  15. lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
  16. errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
  17. sum(power(rSet[:,-1] - tree['right'],2))
  18. treeMean = (tree['left']+tree['right'])/2.0
  19. errorMerge = sum(power(testData[:,-1] - treeMean,2))
  20. if errorMerge < errorNoMerge:
  21. print("merging")
  22. return treeMean
  23. else: return tree
  24. else: return tree

image.pngimage.png

模型树

  1. def linearSolve(dataSet): #helper function used in two places
  2. m,n = shape(dataSet)
  3. X = mat(ones((m,n))); Y = mat(ones((m,1)))#create a copy of data with 1 in 0th postion
  4. X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]#and strip out Y
  5. xTx = X.T*X
  6. if linalg.det(xTx) == 0.0:
  7. raise NameError('This matrix is singular, cannot do inverse,\n\
  8. try increasing the second value of ops')
  9. ws = xTx.I * (X.T * Y)
  10. return ws,X,Y
  11. def modelLeaf(dataSet):#create linear model and return coeficients
  12. ws,X,Y = linearSolve(dataSet)
  13. return ws
  14. def modelErr(dataSet):
  15. ws,X,Y = linearSolve(dataSet)
  16. yHat = X * ws
  17. return sum(power(Y - yHat,2))

image.png

示例:树回归与标准回归的比较

  1. def regTreeEval(model, inDat):
  2. return float(model)
  3. def modelTreeEval(model, inDat):
  4. n = shape(inDat)[1]
  5. X = mat(ones((1,n+1)))
  6. X[:,1:n+1]=inDat
  7. return float(X*model)
  8. def treeForeCast(tree, inData, modelEval=regTreeEval):
  9. if not isTree(tree): return modelEval(tree, inData)
  10. if inData[tree['spInd']] > tree['spVal']:
  11. if isTree(tree['left']): return treeForeCast(tree['left'], inData, modelEval)
  12. else: return modelEval(tree['left'], inData)
  13. else:
  14. if isTree(tree['right']): return treeForeCast(tree['right'], inData, modelEval)
  15. else: return modelEval(tree['right'], inData)
  16. def createForeCast(tree, testData, modelEval=regTreeEval):
  17. m=len(testData)
  18. yHat = mat(zeros((m,1)))
  19. for i in range(m):
  20. yHat[i,0] = treeForeCast(tree, mat(testData[i]), modelEval)
  21. return yHat

image.png
image.png
image.png

image.png