import numpy as np
import matplotlib.pyplot as plt
# Training points, one column per point: row 0 holds x-coordinates and
# row 1 holds y-coordinates (divide() and plot() index them that way).
# p is drawn as red crosses and n as blue circles; the names presumably
# mean "positive"/"negative" class — NOTE(review): confirm.
p = np.array([
    [1, 0.3, 2, 1],      # x
    [2, 0.3, 1, 1],      # y
])
n = np.array([
    [1.5, 1.7, 2, 1.5],  # x
    [2, 1.5, 2, 2.5],    # y
])
def divide(dist, k, X, Y, pos=None, neg=None):
    """Classify query points by comparing per-class nearest-neighbour distances.

    For every query point (X[i], Y[i]) the distances to all points of each
    class are sorted. The ((k - 1) // 2)-th smallest (0-based) of the two
    classes are then compared. For odd k this equals the k-nearest-neighbour
    majority vote between the two classes. For even k it is only an
    approximation of a vote.

    Parameters
    ----------
    dist : callable
        ``dist(dx, dy)`` returning elementwise distances. It is applied to
        2-D arrays of coordinate offsets (queries x class points), so it
        must operate elementwise / broadcast like ``abs`` or ``np.sqrt``.
    k : int
        Number of neighbours; must be >= 1.
    X, Y : array_like
        Query coordinates, one entry per query point.
    pos, neg : array_like of shape (2, m), optional
        Class point sets with row 0 = x and row 1 = y. Default to the
        module-level ``p`` and ``n``.

    Returns
    -------
    numpy.ndarray of bool, shape (len(X),)
        True where the ``neg`` neighbour is strictly closer (``neg`` wins).
        Ties go to ``pos`` (False).

    Raises
    ------
    ValueError
        If ``k < 1``.
    IndexError
        If a class has fewer than ``(k - 1) // 2 + 1`` points.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k!r}")
    pos = p if pos is None else np.asarray(pos)
    neg = n if neg is None else np.asarray(neg)
    # Column vectors so that (class points) - (queries) broadcasts to
    # shape (num_queries, num_class_points) in a single vectorized pass.
    qx = np.asarray(X).reshape(-1, 1)
    qy = np.asarray(Y).reshape(-1, 1)
    # With k = 2j - 1, a class wins the vote iff its j-th nearest point is
    # closer than the other class's j-th nearest; j-th is index j - 1.
    m = int((k - 1) // 2)
    d_pos = np.sort(dist(pos[0] - qx, pos[1] - qy), axis=1)
    d_neg = np.sort(dist(neg[0] - qx, neg[1] - qy), axis=1)
    return d_pos[:, m] > d_neg[:, m]
def dist1(x, y):
    """Return the Manhattan (L1) length |x| + |y| of the offset (x, y).

    Works elementwise on scalars or NumPy arrays via the builtin ``abs``.
    """
    ax_len = abs(x)
    ay_len = abs(y)
    return ax_len + ay_len
def dist2(x, y):
    """Return the Euclidean (L2) length sqrt(x**2 + y**2) of the offset (x, y).

    Works elementwise on scalars or NumPy arrays. ``np.hypot`` is used
    instead of ``np.sqrt(x*x + y*y)`` because it avoids intermediate
    overflow/underflow (e.g. x = y = 1e200 would otherwise give inf).
    """
    return np.hypot(x, y)
def plot(dist, k, ax):
    """Shade the k-NN decision regions for the module-level point sets.

    The square [0, 3] x [0, 3] is sampled on a 200 x 200 grid. Each grid
    point is classified with ``divide`` using this ``dist`` and ``k``, and
    the boolean result is filled with ``contourf``. The ``p`` points are
    overlaid as red crosses and the ``n`` points as blue circles. The label
    "k=<k>" is written at (0.5, 2.5).

    Parameters
    ----------
    dist : callable
        Distance function forwarded to ``divide``.
    k : int
        Number of neighbours used for the classification; also shown in
        the label.
    ax : matplotlib Axes or the ``matplotlib.pyplot`` module
        Anything exposing ``contourf``, ``plot`` and ``text``; the script
        at the bottom of the file passes ``plt`` itself.
    """
    N = 200  # grid resolution per axis
    axis = np.linspace(0, 3, N)
    gx, gy = np.meshgrid(axis, axis)
    # Classify with the caller's k so the shaded regions match the label.
    predict = divide(dist, k, gx.ravel(), gy.ravel())
    ax.contourf(gx, gy, predict.reshape(gx.shape), cmap=plt.cm.Spectral, alpha=0.3)
    ax.plot(p[0], p[1], 'rx')
    ax.plot(n[0], n[1], 'bo')
    # Draw via ax (not the implicit current axes) so the label lands on
    # the same axes as the regions when a real Axes object is passed.
    ax.text(0.5, 2.5, f"k={k}")
# Demo: shade the 3-NN decision regions under the Euclidean metric.
# Pass dist1 instead of dist2 to see the Manhattan-metric version.
plot(dist=dist2, k=3, ax=plt)
plt.show()

