1、公式法

通过公式法可以直接求出线性回归的最优解
![正规方程:θ = (XᵀX)⁻¹ Xᵀ y](image.png)

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. X = 2 * np.random.rand(100, 1) # 100*1矩阵
  4. y = 12 + 5 * X + np.random.randn(100, 1)
  5. plt.plot(X, y, 'b.')
  6. plt.xlabel("X_1")
  7. plt.ylabel("y")
  8. plt.axis([0, 2, 10, 20])
  9. plt.show()
  10. number_data = len(X)
  11. X_b = np.c_[np.ones((100, 1)), X]
  12. theta_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)
  13. print(theta_best)

2、梯度下降法

2.1 批量梯度下降法(Batch Gradient Descent,BGD)

批量梯度下降法是最原始的形式,它是指在每一次迭代时使用所有样本来进行梯度的更新。

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. X = 2 * np.random.rand(100, 1) # 100*1矩阵
  4. y = 12 + 5 * X + np.random.randn(100, 1)
  5. plt.plot(X, y, 'b.')
  6. plt.xlabel("X_1")
  7. plt.ylabel("y")
  8. plt.axis([0, 2, 10, 20])
  9. plt.show()
  10. eta = 0.01
  11. n_iterations = 100
  12. number_data = len(X)
  13. X_b = np.c_[np.ones((number_data, 1)), X]
  14. # 随机设定模型参数
  15. theta = np.random.randn(2, 1)
  16. cost_history = []
  17. for iteration in range(n_iterations):
  18. # 计算损失
  19. cost = 1 / (2 * number_data) * (y - X_b.dot(theta)).T.dot(y - X_b.dot(theta))[0][0]
  20. cost_history.append(cost)
  21. # 计算梯度
  22. gradient = 2 / number_data * X_b.T.dot(X_b.dot(theta) - y)
  23. # 更新模型参数
  24. theta = theta - eta*gradient
  25. plt.plot(range(len(cost_history)), cost_history)
  26. plt.show()

2.2 随机梯度下降法(Stochastic Gradient Descent,SGD)

随机梯度下降法每次迭代只随机抽取一个样本来计算梯度并更新参数,单步开销小但更新方向噪声较大,通常配合逐渐减小的学习率使用。
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. X = 2 * np.random.rand(100, 1) # 100*1矩阵
  4. y = 12 + 5 * X + np.random.randn(100, 1)
  5. plt.plot(X, y, 'b.')
  6. plt.xlabel("X_1")
  7. plt.ylabel("y")
  8. plt.axis([0, 2, 10, 20])
  9. plt.show()
  10. number_data = len(X)
  11. X_b = np.c_[np.ones((number_data, 1)), X]
  12. theta = np.random.randn(2, 1)
  13. def learning_schedule(t):
  14. """
  15. 返回当前的学习率,学习率先大后小
  16. :param t:
  17. :return:
  18. """
  19. t0 = 5
  20. t1 = 50
  21. return t0 / (t1 + t)
  22. n_epochs = 50 # 学习迭代的次数
  23. m = len(X_b) # 样本的数量
  24. theta_path_sgd = [] # 保存theta更新路径
  25. cost_path_sgd = [] # 保存损失值的更新路径
  26. for epoch in range(n_epochs):
  27. for i in range(m):
  28. random_index = np.random.randint(m)
  29. xi = X_b[random_index:random_index + 1]
  30. yi = y[random_index:random_index + 1]
  31. gradient = 2 * xi.T.dot(xi.dot(theta) - yi)
  32. eta = learning_schedule(n_epochs * m + i)
  33. # 保存theta更新
  34. theta = theta - eta * gradient
  35. theta_path_sgd.append(theta)
  36. # 保存损失值更新
  37. cost = (y - X_b.dot(theta)).T.dot((y - X_b.dot(theta)))[0][0]
  38. cost_path_sgd.append(cost)
  39. plt.plot(range(len(cost_path_sgd)), cost_path_sgd)
  40. plt.show()

2.3 小批量梯度下降(Mini-Batch Gradient Descent,MBGD)

小批量梯度下降法是前两者的折中:每次迭代随机打乱样本后,取一小批(如 16 个)样本计算梯度并更新参数,兼顾了计算效率和更新的稳定性。
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. X = 2 * np.random.rand(100, 1) # 100*1矩阵
  4. y = 12 + 5 * X + np.random.randn(100, 1)
  5. plt.plot(X, y, 'b.')
  6. plt.xlabel("X_1")
  7. plt.ylabel("y")
  8. plt.axis([0, 2, 10, 20])
  9. plt.show()
  10. # eta = 0.01
  11. # n_iterations = 100
  12. number_data = len(X)
  13. X_b = np.c_[np.ones((number_data, 1)), X]
  14. def learning_schedule(t):
  15. """
  16. 返回当前的学习率,学习率先大后小
  17. :param t:
  18. :return:
  19. """
  20. t0 = 5
  21. t1 = 50
  22. return t0 / (t1 + t)
  23. n_epochs = 50 # 学习迭代的次数
  24. minibatch = 16
  25. theta = np.random.randn(2, 1) # 随机设定模型参数
  26. t = 0
  27. m = len(X_b)
  28. cost_path_mgd = []
  29. theta_path_mgd = []
  30. for epoch in range(n_epochs):
  31. shuffled_indices = np.random.permutation(m)
  32. X_b_shuffled = X_b[shuffled_indices]
  33. y_shuffled = y[shuffled_indices]
  34. for i in range(0, m, minibatch):
  35. t += 1
  36. xi = X_b_shuffled[i:minibatch + i]
  37. yi = y_shuffled[i:minibatch + i]
  38. gradients = 2 / minibatch * xi.T.dot(xi.dot(theta) - yi)
  39. eta = learning_schedule(t)
  40. theta = theta - eta * gradients
  41. theta_path_mgd.append(theta)
  42. cost = (y - X_b.dot(theta)).T.dot((y - X_b.dot(theta)))[0][0]
  43. cost_path_mgd.append(cost)
  44. plt.plot(range(len(cost_path_mgd)), cost_path_mgd)
  45. plt.show()