背景

来自Coursera上的斯坦福大学机器学习课程:现有47个房子的面积和价格,需要建立一个模型对新的房价进行预测。

  • 输入数据只有一维,亦房子的面积。
  • 目标数据也只有一维,亦即房子的价格。
  • 需要作的,就是根据已知的房子的面积和价格的关系进行机器学习

步骤

  1. 获取与处理数据
  2. 选择与训练模型
  3. 评估与可视化结果

代码

  1. # 导入需要用到的库
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. '''
  5. 第一步:获取与处理数据
  6. '''
  7. # 定义存储输入数据(x)和目标数据(y)的数组
  8. x, y = [], []
  9. # 遍历数据集,变量sample对应的正是一个个样本
  10. for sample in open("D:/Study/2022/06/prices.txt", "r"):
  11. # 由于数据是用逗号隔开的,所以调用Python的split方法并将逗号作为参数传入
  12. _x, _y = sample.split(",")
  13. # 将字符串数据转化为浮点数
  14. x.append(float(_x))
  15. y.append(float(_y))
  16. # 读取完数据后,将他们转化为Numpy数组以方便进一步的处理
  17. x, y = np.array(x), np.array(y)
  18. # 标准化
  19. x = (x - x.mean()) / x.std()
  20. # 将原始数据以散点图的形式画出
  21. plt.figure()
  22. plt.scatter(x, y, c = "g", s = 6)
  23. plt.show()
  24. '''
  25. 第二步:选择与训练模型
  26. '''
  27. # 此处选择线性回归的多项式模型,采用常见的平方损失函数,借助Numpy库实现
  28. # 在(-2,4)这个区间上取100个点作为画图的基础
  29. x0 = np.linspace(-2, 4, 100)
  30. # 利用Numpy的函数定义训练并返回多项式回归模型的函数
  31. # deg参数代表着模型参数中的n,亦即模型中多项式的次数
  32. # 返回的模型能够根据输入的x(默认是x0),返回相对应的预测的y
  33. def get_model(deg):
  34. return lambda input_x = x0: np.polyval(np.polyfit(x, y, deg), input_x)
  35. # polyfit(x, y, deg):该函数返回使得平方损失函数最小的参数p(多项式f的各项系数),该函数就是模型的训练函数。
  36. # polyval(p, x):根据多项式的各项系数p和多项式x的值,返回多项式的值y
  37. '''
  38. 第三步:评估与可视化结果
  39. '''
  40. # 根据参数n,输入的x,y返回相对应的损失
  41. def get_cost(deg, input_x, input_y):
  42. return 0.5 * ((get_model(deg)(input_x) - input_y) ** 2).sum()
  43. # 定义测试参数集并根据它进行各种实验
  44. test_set = (1, 4, 10)
  45. for d in test_set:
  46. # 输出相应的损失
  47. print(get_cost(d, x, y))
  48. # n=1:96732238800.35292
  49. # n=4:94112406641.67741
  50. # n=10:75874846680.09283
  51. # 通过损失值看出n=10优于n=4,n=1最差,通过画图直观了解是否出现过拟合
  52. # 画出相应的图像
  53. plt.scatter(x, y, c = "g", s = 20)
  54. for d in test_set:
  55. plt.plot(x0, get_model(d)(), label = "degree = {}".format(d))
  56. # 将横轴、纵轴的范围分别限制在(-2,4)、(10^5, 8x10^5)
  57. plt.xlim(-2, 4)
  58. plt.ylim(1e5, 8e5)
  59. # 调用legend方法使曲线对应的label正确显示
  60. plt.legend()
  61. plt.show()
  62. # 通过画图看出n=4开始出现过拟合现象,n=10模型已经非常不合理

结果

image.png
image.png

附件

prices.txt