准备数据

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from sklearn.linear_model import LinearRegression
  4. # 1. 数据展示:了解数据,可以是csv读取,也可以直接copy进来
  5. years = np.arange(2009, 2020)
  6. sales = np.array([0.52, 9.36, 33.6, 132, 352, 571, 912, 1207, 1682, 2135, 2684])
  7. print(years)
  8. print(sales)

可视化

一般使用散点图

  1. plt.scatter(years, sales, c='red')

image.png

初步判断

多项式回归(3阶)

  1. y = a*x^3 + b*x^2 + c*x + d
  2. 1: 1 1 1
  3. 2: 8 4 2
  4. 3: 27 9 3

数据预处理

  1. model_y = sales
  2. model_x = (years - 2008).reshape(-1, 1) # 任意行,一列
  3. print(model_x)
  4. model_x = np.concatenate([model_x ** 3, model_x ** 2, model_x], axis=1)
  5. print(model_x)

image.png

建模

  1. # 4. 创建回归模型(多项式->1元3次)
  2. model = LinearRegression()
  3. # 5. 数据训练
  4. model.fit(model_x, model_y)
  5. # 6. 获取系数、截距 -> 声明方程式
  6. print('系数:', model.coef_) # 系数: [ -0.20964258 34.42433566 -117.85390054]
  7. print('截距:', model.intercept_) # 截距: 90.12060606060629

y = -0.20964258*x^3 + 34.42433566*x^2 + -117.85390054*x + 90.12060606060629

绘图

  1. # 7. 添加趋势线:想象成画折线图,x:1~11,12 y:带入公式之后得到的
  2. trend_x = np.linspace(1, 12, 100)
  3. fun = lambda x: -0.20964258 * x ** 3 + 34.42433566 * x ** 2 + -117.85390054 * x + 90.12060606060629
  4. trend_y = fun(trend_x)
  5. # print(type(fun))
  6. # print(trend_x)
  7. # print(trend_y)
  8. years_no = years - 2008
  9. plt.scatter(years_no, sales, c='red') # 画散点图
  10. plt.plot(trend_x, trend_y, c='green') # 画趋势线
  11. plt.show()

image.png

预测

  1. # 8. 预测2020年的销售额
  2. print('2020年销售额预测:', fun(12))
  3. years_no = years - 2008
  4. plt.scatter(years_no, sales, c='red')
  5. plt.scatter(12, fun(12), c='blue')
  6. plt.plot(trend_x, trend_y, c='green')
  7. # 加数据标签
  8. plt.annotate(fun(12), xy=(12, fun(12))) # annotate : (参数1:注释文本的内容, 参数2:被注释的坐标点)

image.png
给其它坐标加标签

  1. for i in range(11):
  2. plt.annotate(sales[i], xy=(years_no[i], sales[i]))
  3. plt.show()

image.png

调整标签的位置

  1. # x右移0.5,y下移0.5
  2. plt.annotate(round(fun(12), 1), xy=(12, fun(12)), xytext=(12 + 0.5, fun(12) - 0.5))
  3. for i in range(11):
  4. plt.annotate(sales[i], xy=(years_no[i], sales[i]), xytext=(years_no[i] + 0.5, sales[i] - 0.5))
  5. plt.show()

image.png