欠拟合

image.png

过拟合

image.png

模型复杂度与拟合程度

image.png

测试数据集的意义

image.png

  # Compare test-set MSE of models of increasing complexity (underfit -> good fit -> overfit).
# NOTE(review): X, y, LinearRegression, PolynomialRegression and mean_squared_error
# are not defined in this snippet; they are assumed to come from earlier in these
# notes. PolynomialRegression(degree=d) presumably builds a degree-d polynomial
# feature expansion followed by linear regression — confirm against its definition.
from sklearn.model_selection import train_test_split

# Hold out a test split. The fixed random_state makes the split reproducible,
# which is what makes the recorded MSE values below comparable between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)

# Plain linear regression.
lin_reg = LinearRegression()
lin_reg.fit(X_train, y_train)
y_predict = lin_reg.predict(X_test)
# print() so the score is shown when run as a script, not only in a notebook/REPL.
print(mean_squared_error(y_test, y_predict))  # 2.2199965269396573

# Degree-2 polynomial regression: the lowest test error of the models tried here.
poly2_reg = PolynomialRegression(degree=2)
poly2_reg.fit(X_train, y_train)
y2_predict = poly2_reg.predict(X_test)
print(mean_squared_error(y_test, y2_predict))  # 0.80356410562978997

# Degree-10 polynomial regression: test error starts rising again.
poly10_reg = PolynomialRegression(degree=10)
poly10_reg.fit(X_train, y_train)
y10_predict = poly10_reg.predict(X_test)
print(mean_squared_error(y_test, y10_predict))  # 0.92129307221507939

# Degree far too high: the model overfits the training data and generalizes
# poorly, so the test-set error explodes.
poly100_reg = PolynomialRegression(degree=100)
poly100_reg.fit(X_train, y_train)
y100_predict = poly100_reg.predict(X_test)
print(mean_squared_error(y_test, y100_predict))  # 14075796419.234262