from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
def plot_learning_curve(algo, X_train, X_test, y_train, y_test, y_max=4):
    """Plot train/test RMSE learning curves for a learner.

    For every training-set size ``i`` in ``1 .. len(X_train)``, the learner is
    refit on the first ``i`` training samples. The RMSE is then measured on
    those same ``i`` samples ("train") and on the whole test set ("test").
    Both curves are plotted against ``i`` and the figure is shown.

    Args:
        algo: the learner; any object exposing ``fit(X, y)`` and
            ``predict(X)``. It is refit in place on every iteration.
        X_train: training features; must support prefix slicing ``[:i]``.
        X_test: test features, always evaluated in full.
        y_train: training targets, sliced in step with ``X_train``.
        y_test: test targets.
        y_max: upper limit of the y-axis (RMSE). Defaults to 4, the value
            that was previously hard-coded.
    """
    n_train = len(X_train)
    sizes = list(range(1, n_train + 1))

    train_score = []
    test_score = []
    for i in sizes:
        algo.fit(X_train[:i], y_train[:i])
        y_train_predict = algo.predict(X_train[:i])
        train_score.append(mean_squared_error(y_train[:i], y_train_predict))
        y_test_predict = algo.predict(X_test)
        test_score.append(mean_squared_error(y_test, y_test_predict))

    # Both score lists hold one entry per training-set size, so both curves
    # share the same x values. Using len(X_test) for the test curve breaks
    # whenever the two sets differ in size.
    plt.plot(sizes, np.sqrt(train_score), label='train')
    plt.plot(sizes, np.sqrt(test_score), label='test')
    plt.legend()  # needed for the 'train' / 'test' labels to be shown
    # Plot range: [x_min, x_max, y_min, y_max].
    plt.axis([0, n_train + 1, 0, y_max])
    plt.show()
# Draw the learning curves of an ordinary least-squares linear regression
# (sklearn's class is spelled ``LinearRegression``).
plot_learning_curve(LinearRegression(), X_train, X_test, y_train, y_test)