regression.py

  1. from numpy import *
  2. def loadDataSet(fileName): #general function to parse tab -delimited floats
  3. numFeat = len(open(fileName).readline().split('\t')) - 1 #get number of fields
  4. dataMat = []; labelMat = []
  5. fr = open(fileName)
  6. for line in fr.readlines():
  7. lineArr =[]
  8. curLine = line.strip().split('\t')
  9. for i in range(numFeat):
  10. lineArr.append(float(curLine[i]))
  11. dataMat.append(lineArr)
  12. labelMat.append(float(curLine[-1]))
  13. return dataMat,labelMat
  14. def standRegres(xArr,yArr):
  15. xMat = mat(xArr); yMat = mat(yArr).T
  16. xTx = xMat.T*xMat
  17. if linalg.det(xTx) == 0.0:
  18. print("This matrix is singular, cannot do inverse")
  19. return
  20. ws = xTx.I * (xMat.T*yMat)
  21. return ws

image.png
image.png
image.png

局部加权线性回归

(对待预测点附近的样本赋予更高的权重,从而根据局部数据来调整预测)

  1. def lwlr(testPoint,xArr,yArr,k=1.0):
  2. xMat = mat(xArr); yMat = mat(yArr).T
  3. m = shape(xMat)[0]
  4. weights = mat(eye((m)))
  5. for j in range(m): #next 2 lines create weights matrix
  6. diffMat = testPoint - xMat[j,:] #
  7. weights[j,j] = exp(diffMat*diffMat.T/(-2.0*k**2))
  8. xTx = xMat.T * (weights * xMat)
  9. if linalg.det(xTx) == 0.0:
  10. print("This matrix is singular, cannot do inverse")
  11. return
  12. ws = xTx.I * (xMat.T * (weights * yMat))
  13. return testPoint * ws
  14. def lwlrTest(testArr,xArr,yArr,k=1.0): #loops over all the data points and applies lwlr to each one
  15. m = shape(testArr)[0]
  16. yHat = zeros(m)
  17. for i in range(m):
  18. yHat[i] = lwlr(testArr[i],xArr,yArr,k)
  19. return yHat

image.png

缩减系数来理解数据

当特征数大于样本数量时,无法再用前面的方法求解,因为此时 $X^TX$ 不是满秩矩阵,求逆会出现问题。

岭回归

通过在 $X^TX$ 上加上 $\lambda I$($I$ 为对角线为 1、其余元素为 0 的单位矩阵)使矩阵非奇异,回归系数为 $\hat{w} = (X^TX + \lambda I)^{-1}X^Ty$。
引入 $\lambda$ 相当于对系数施加惩罚,能够缩减不重要的参数。

  1. def ridgeRegres(xMat,yMat,lam=0.2):
  2. xTx = xMat.T*xMat
  3. denom = xTx + eye(shape(xMat)[1])*lam
  4. if linalg.det(denom) == 0.0:
  5. print("This matrix is singular, cannot do inverse")
  6. return
  7. ws = denom.I * (xMat.T*yMat)
  8. return ws
  9. def ridgeTest(xArr,yArr):
  10. xMat = mat(xArr); yMat=mat(yArr).T
  11. yMean = mean(yMat,0)
  12. yMat = yMat - yMean #to eliminate X0 take mean off of Y
  13. #regularize X's
  14. xMeans = mean(xMat,0) #calc mean then subtract it off
  15. xVar = var(xMat,0) #calc variance of Xi then divide by it
  16. xMat = (xMat - xMeans)/xVar
  17. numTestPts = 30
  18. wMat = zeros((numTestPts,shape(xMat)[1]))
  19. for i in range(numTestPts):
  20. ws = ridgeRegres(xMat,yMat,exp(i-10))
  21. wMat[i,:]=ws.T
  22. return wMat

image.png
image.png

前向逐步线性回归

  1. def stageWise(xArr,yArr,eps=0.01,numIt=100):
  2. xMat = mat(xArr); yMat=mat(yArr).T
  3. yMean = mean(yMat,0)
  4. yMat = yMat - yMean #can also regularize ys but will get smaller coef
  5. xMat = regularize(xMat)
  6. m,n=shape(xMat)
  7. returnMat = zeros((numIt,n)) #testing code remove
  8. ws = zeros((n,1)); wsTest = ws.copy(); wsMax = ws.copy()
  9. for i in range(numIt):
  10. print(ws.T)
  11. lowestError = inf;
  12. for j in range(n):
  13. for sign in [-1,1]:
  14. wsTest = ws.copy()
  15. wsTest[j] += eps*sign
  16. yTest = xMat*wsTest
  17. rssE = rssError(yMat.A,yTest.A)
  18. if rssE < lowestError:
  19. lowestError = rssE
  20. wsMax = wsTest
  21. ws = wsMax.copy()
  22. returnMat[i,:]=ws.T
  23. return returnMat

image.png

权衡偏差和方差

示例:预测乐高玩具套装的价格

获取数据

  1. from time import sleep
  2. import json
  3. import urllib.request
  4. def searchForSet(retX, retY, setNum, yr, numPce, origPrc):
  5. sleep(10)
  6. myAPIstr = 'AIzaSyD2cR2KFyx12hXu6PFU-wrWot3NXvko8vY'
  7. searchURL = 'https://www.googleapis.com/shopping/search/v1/public/products?key=%s&country=US&q=lego+%d&alt=json' % (myAPIstr, setNum)
  8. pg = urllib.request.urlopen(searchURL)
  9. retDict = json.loads(pg.read())
  10. for i in range(len(retDict['items'])):
  11. try:
  12. currItem = retDict['items'][i]
  13. if currItem['product']['condition'] == 'new':
  14. newFlag = 1
  15. else: newFlag = 0
  16. listOfInv = currItem['product']['inventories']
  17. for item in listOfInv:
  18. sellingPrice = item['price']
  19. if sellingPrice > origPrc * 0.5:
  20. print("%d\t%d\t%d\t%f\t%f" % (yr,numPce,newFlag,origPrc, sellingPrice))
  21. retX.append([yr, numPce, newFlag, origPrc])
  22. retY.append(sellingPrice)
  23. except: print('problem with item %d' % i)
  24. def setDataCollect(retX, retY):
  25. searchForSet(retX, retY, 8288, 2006, 800, 49.99)
  26. searchForSet(retX, retY, 10030, 2002, 3096, 269.99)
  27. searchForSet(retX, retY, 10179, 2007, 5195, 499.99)
  28. searchForSet(retX, retY, 10181, 2007, 3428, 199.99)
  29. searchForSet(retX, retY, 10189, 2008, 5922, 299.99)
  30. searchForSet(retX, retY, 10196, 2009, 3263, 249.99)