一、线性扫描法

  import logging

  import numpy as np
  from tensorflow import keras
  # Configure logging once at import time; each record shows the timestamp,
  # logger name, level and message.
  logging.basicConfig(
      level=logging.INFO,
      format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
  )
  # Module-level logger used by the script section of this file.
  logger = logging.getLogger(__name__)
  def calDist(x1, x2):
      """Return the Euclidean (L2) distance between two points.

      Both inputs are converted to float64 before subtracting, so
      unsigned-integer data (for example uint8 image pixels) cannot wrap
      around during the subtraction or the squaring.

      Args:
          x1: Array-like coordinates of the first point.
          x2: Array-like coordinates of the second point; must be
              broadcast-compatible with ``x1``.

      Returns:
          The distance as a NumPy floating-point scalar.
      """
      diff = np.asarray(x1, dtype=np.float64) - np.asarray(x2, dtype=np.float64)
      # Same value as the L2 norm of the flattened difference vector,
      # i.e. np.linalg.norm(diff.ravel()).
      return np.sqrt(np.sum(np.square(diff)))
  def getClostest(trainData, trainLabel, x, topK):
      """Predict the class of sample ``x`` by a k-nearest-neighbour vote.

      The distance from ``x`` to every training sample is computed with
      ``calDist`` (a linear scan over ``trainData``); the ``topK`` closest
      samples then vote with their labels.

      Args:
          trainData: Training samples, iterated along the first axis.
          trainLabel: Labels of the training samples, one per sample. Each
              label is converted to a non-negative integer for counting.
          x: The sample to classify; must be shape-compatible with one
              training sample.
          topK: Number of nearest neighbours that take part in the vote.

      Returns:
          int: The most frequent label among the selected neighbours. Ties
          go to the smallest label; 0 is returned when no neighbour is
          selected.

      Raises:
          ValueError: If ``trainData`` and ``trainLabel`` differ in length.
      """
      if len(trainData) != len(trainLabel):
          raise ValueError(
              "trainData and trainLabel must have the same length, got "
              f"{len(trainData)} and {len(trainLabel)}"
          )

      # A float64 query makes every ``sample - query`` inside calDist promote
      # to float64, so unsigned-integer samples cannot wrap around.
      query = np.asarray(x, dtype=np.float64)
      distList = [calDist(sample, query) for sample in trainData]

      # Indices of the topK nearest samples; the stable sort keeps
      # equidistant samples in training order.
      topKList = np.argsort(distList, kind="stable")[:topK]

      # Count the votes per label. bincount sizes itself to the largest label
      # seen, and minlength=1 keeps argmax defined when nothing was selected.
      neighbourLabels = np.asarray(trainLabel)[topKList].astype(np.int64)
      votes = np.bincount(neighbourLabels, minlength=1)
      return int(np.argmax(votes))
  if __name__ == "__main__":
      # Load MNIST and classify the first test images with 25-nearest-neighbour
      # voting, logging each (prediction, truth) pair and the final accuracy.
      (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

      # Only a small prefix of the test set is evaluated because every
      # prediction scans the whole training set.
      samples = x_test[:100]
      sampleCnt = len(samples)

      errCnt = 0
      for x, expected in zip(samples, y_test[:sampleCnt]):
          y = getClostest(x_train, y_train, x, 25)
          logger.info((y, expected))
          if y != expected:
              errCnt += 1

      # Derive the percentage from the real sample count instead of relying on
      # the count being exactly 100.
      logger.info('correct rate:%d%%', 100 * (sampleCnt - errCnt) // sampleCnt)

二、kd树