import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error


def plot_learning_curve(algo, X_train, X_test, y_train, y_test):
    """Plot train/test RMSE learning curves for an estimator.

    For every prefix size ``i`` from 1 to ``len(X_train)``, the estimator is
    refit on the first ``i`` training samples. After each fit, the function
    records the MSE on those same ``i`` training samples and the MSE on the
    full test set. It then plots the square roots (RMSE) of both series
    against ``i`` and shows the figure.

    Args:
        algo: The learner. Any object with ``fit(X, y)`` and ``predict(X)``.
            It is refit in place on every iteration.
        X_train: Training features; must support prefix slicing ``[:i]``.
        X_test: Held-out features, evaluated in full at every step.
        y_train: Training targets, aligned with ``X_train``.
        y_test: Held-out targets, aligned with ``X_test``.
    """
    train_score = []
    test_score = []
    # One point per training-set size; both curves share these x values.
    sizes = list(range(1, len(X_train) + 1))
    for i in sizes:
        algo.fit(X_train[:i], y_train[:i])

        y_train_predict = algo.predict(X_train[:i])
        train_score.append(mean_squared_error(y_train[:i], y_train_predict))

        y_test_predict = algo.predict(X_test)
        test_score.append(mean_squared_error(y_test, y_test_predict))

    # test_score has len(X_train) entries, so it must be plotted against the
    # training sizes, not against len(X_test).
    plt.plot(sizes, np.sqrt(train_score), label="train")
    plt.plot(sizes, np.sqrt(test_score), label="test")
    plt.legend()  # required for the curve labels to be displayed
    # Plot range: [xmin, xmax, ymin, ymax]; the y-axis cap of 4 is hard-coded.
    plt.axis([0, len(X_train) + 1, 0, 4])
    plt.show()


plot_learning_curve(LinearRegression(), X_train, X_test, y_train, y_test)