"""Visualise the decision regions of a two-class k-nearest-neighbour rule.

Two small hand-written point sets (``p`` and ``n``) act as training data.
Every node of a regular grid over the square [0, 3] x [0, 3] is classified
and the resulting regions are drawn behind the training points.
"""
import numpy as np
import matplotlib.pyplot as plt

# Training data, one array per class: row 0 holds x coordinates, row 1 holds
# y coordinates, so column j is the j-th point of that class.
p = np.array([[1, 0.3, 2, 1], [2, 0.3, 1, 1]])
n = np.array([[1.5, 1.7, 2, 1.5], [2, 1.5, 2, 2.5]])


def divide(dist, k, X, Y):
    """Classify the query points ``(X[i], Y[i])`` against ``p`` and ``n``.

    For each query point the distances to all points of each class are
    sorted, and the distance at index ``int((k - 1) / 2)`` is compared
    between the two classes.  For odd ``k`` (and no distance ties) this is
    equivalent to a majority vote among the ``k`` nearest neighbours: the
    class whose ``(k + 1) / 2``-th nearest member is closer holds the
    majority.

    Args:
        dist: Callable ``dist(dx, dy)`` taking two 1-D arrays of coordinate
            differences and returning the array of distances.
        k: Number of neighbours taking part in the vote.
        X: Sequence of query x coordinates.
        Y: Sequence of query y coordinates, paired element-wise with ``X``.

    Returns:
        A 1-D array with one entry per query point; an entry is true when
        the selected distance to class ``p`` is strictly greater than the
        one to class ``n`` (i.e. the point is assigned to ``n``).

    Raises:
        IndexError: If ``int((k - 1) / 2)`` is not a valid index into the
            distances of either class (``k`` too large for the data).
    """
    rank = int((k - 1) / 2)

    def rank_distance(points, x, y):
        # Distance from (x, y) to the class member selected by ``rank``.
        return np.sort(dist(points[0] - x, points[1] - y))[rank]

    return np.array(
        [rank_distance(p, x, y) > rank_distance(n, x, y) for x, y in zip(X, Y)]
    )


def dist1(x, y):
    """Return the Manhattan (L1) length of the offset ``(x, y)``."""
    return abs(x) + abs(y)


def dist2(x, y):
    """Return the Euclidean (L2) length of the offset ``(x, y)``."""
    return np.sqrt(x * x + y * y)


def plot(dist, k, ax, resolution=200):
    """Draw the k-NN decision regions and the training points on ``ax``.

    Args:
        dist: Distance callable forwarded to :func:`divide`.
        k: Number of neighbours used for classification and shown in the
            "k=" label.
        ax: Drawing target offering ``contourf``, ``plot`` and ``text``;
            either a matplotlib ``Axes`` or the ``matplotlib.pyplot``
            module itself.
        resolution: Number of grid nodes per axis (default 200).
    """
    axis = np.linspace(0, 3, resolution)
    grid_x, grid_y = np.meshgrid(axis, axis)

    # divide() expects flat coordinate sequences, so classify the flattened
    # grid and fold the answer back into the grid shape for contourf.
    predict = divide(dist, k, grid_x.ravel(), grid_y.ravel())
    ax.contourf(
        grid_x,
        grid_y,
        predict.reshape(grid_x.shape),
        cmap=plt.cm.Spectral,
        alpha=0.3,
    )

    ax.plot(p[0], p[1], 'rx')
    ax.plot(n[0], n[1], 'bo')
    # Label through ``ax`` so the text lands on the axes being drawn, not on
    # whatever pyplot currently considers the active axes.
    ax.text(0.5, 2.5, "k=" + str(k))


def main():
    """Show the Euclidean 3-NN decision regions in a window."""
    # Pass ``dist1`` instead of ``dist2`` to see the Manhattan-distance regions.
    plot(dist2, 3, plt)
    plt.show()


if __name__ == "__main__":
    main()

