一、线性扫描法
import logging

import numpy as np

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)


def calDist(x1, x2):
    """Return the Euclidean (L2) distance between two points.

    Both inputs are converted to float64 *before* the subtraction. Doing the
    arithmetic in the inputs' own dtype is wrong for unsigned integer arrays
    (for example uint8 image pixels): the difference and its square wrap
    around modulo 2**8, so a pixel difference of 16 would contribute 0.

    Args:
        x1: Array-like point of any shape.
        x2: Array-like point broadcast-compatible with ``x1``.

    Returns:
        The distance as a NumPy float64 scalar (equivalent to the L2 norm of
        the flattened difference).
    """
    diff = np.asarray(x1, dtype=np.float64) - np.asarray(x2, dtype=np.float64)
    return np.sqrt(np.sum(np.square(diff)))


def getClostest(trainData, trainLabel, x, topK):
    """Predict the class of sample ``x`` with a brute-force k-NN vote.

    The distance from ``x`` to every training sample is computed by a linear
    scan, the ``topK`` nearest samples are selected, and the most frequent
    label among them is returned. Ties between labels are resolved in favour
    of the smallest label.

    Args:
        trainData: Sequence/array of training samples; iterating it yields
            one sample at a time, each broadcast-compatible with ``x``.
        trainLabel: Sequence/array of non-negative integer class labels, one
            per training sample. Any number of classes is supported.
        x: The sample to classify.
        topK: Number of nearest neighbours that vote; values larger than the
            training-set size use every training sample.

    Returns:
        The predicted label as a built-in ``int``.

    Raises:
        ValueError: If ``trainData`` and ``trainLabel`` differ in length, if
            the training set is empty, or if ``topK`` is smaller than 1.
    """
    if len(trainData) != len(trainLabel):
        raise ValueError(
            'trainData and trainLabel must have the same length, got %d and %d'
            % (len(trainData), len(trainLabel))
        )
    if len(trainLabel) == 0:
        raise ValueError('training set must not be empty')
    if topK < 1:
        raise ValueError('topK must be at least 1, got %r' % (topK,))

    # Convert the query once; calDist's own conversion is then a no-op for it.
    x = np.asarray(x, dtype=np.float64)
    distList = [calDist(xi, x) for xi in trainData]

    # Indices of the topK nearest samples. A stable sort makes equal
    # distances resolve deterministically in training-set order.
    topKList = np.argsort(distList, kind='stable')[:topK]

    # Count votes per label; argmax returns the first (smallest) label among
    # equally frequent ones.
    votes = np.bincount(np.asarray(trainLabel)[topKList].astype(np.int64))
    return int(np.argmax(votes))


if __name__ == "__main__":
    # Imported here rather than at module level so that calDist/getClostest
    # can be imported and reused without loading TensorFlow.
    from tensorflow import keras

    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    numTest = 100  # only the first samples are evaluated; the scan is slow
    samples = x_test[:numTest]
    errCnt = 0
    for x, label in zip(samples, y_test[:numTest]):
        y = getClostest(x_train, y_train, x, 25)
        logger.info((y, label))
        if y != label:
            errCnt += 1

    total = len(samples)
    # Derive the percentage from the number actually evaluated instead of
    # assuming exactly 100 samples.
    logger.info('correct rate:%d%%', 100 * (total - errCnt) // total)
二、kd树