预测数据的标准化

在上一节中,我们在用训练出来的模型预测房屋价格之前,还需要先还原W和B的值,这看上去比较麻烦,下面我们来介绍一种正确的推理方法。

既然我们在训练时可以把样本数据标准化,那么在预测时,把预测数据也做相同方式的标准化,不是就可以和训练数据一样进行预测了吗?且慢!这里有一个问题,训练时的样本数据是批量的,至少是成百成千的数量级。但是预测时,一般只有一个或几个数据,如何做标准化?

我们在针对训练数据做标准化时,得到的最重要的数据是训练数据的最小值和最大值,我们只需要把这两个值记录下来,在预测时使用它们对预测数据做标准化,这就相当于把预测数据“混入”训练数据。前提是预测数据的特征值不能超出训练数据的特征值范围,否则有可能影响准确程度。

代码实现

基于这种想法,我们先给SimpleDataReader类增加一个方法NormalizePredicateData(),如下述代码:

  1. class SimpleDataReader(object):
  2. # normalize data by self range and min_value
  3. def NormalizePredicateData(self, X_raw):
  4. X_new = np.zeros(X_raw.shape)
  5. n = X_raw.shape[1]
  6. for i in range(n):
  7. col_i = X_raw[:,i]
  8. X_new[:,i] = (col_i - self.X_norm[i,0]) / self.X_norm[i,1]
  9. return X_new

X_norm数组中的数据,是在训练时从样本数据中得到的最大值最小值,比如表5-11所示的样例。

表5-11 各个特征值的特征保存

最小值 数值范围(最大值减最小值)
特征值1 2.02 21.96-2.02=19.94
特征值2 40 119-40=79

所以,最后X_new就是按照训练样本的规格标准化好的预测标准化数据,然后我们把这个预测标准化数据放入网络中进行预测:

  1. import numpy as np
  2. from HelperClass.NeuralNet import *
  3. if __name__ == '__main__':
  4. # data
  5. reader = SimpleDataReader()
  6. reader.ReadData()
  7. reader.NormalizeX()
  8. # net
  9. params = HyperParameters(eta=0.01, max_epoch=100, batch_size=10, eps = 1e-5)
  10. net = NeuralNet(params, 2, 1)
  11. net.train(reader, checkpoint=0.1)
  12. # inference
  13. x1 = 15
  14. x2 = 93
  15. x = np.array([x1,x2]).reshape(1,2)
  16. x_new = reader.NormalizePredicateData(x)
  17. z = net.inference(x_new)
  18. print("Z=", z)

运行结果

  1. ......
  2. 199 99 380.5942402877278
  3. [[-40.23494571] [399.40443921]] [[244.388824]]
  4. W= [[-40.23494571]
  5. [399.40443921]]
  6. B= [[244.388824]]
  7. Z= [[486.16645199]]

比较一下正规方程的结果:

  1. z= 486.1051325196855

二者非常接近,可以说这种方法的确很方便,把预测数据看作训练数据的一个记录,先做标准化,再做预测,这样就不需要把权重矩阵还原了。

看上去我们已经完美地解决了这个问题,但是且慢,仔细看看loss值,还有w和b的值,都是几十几百的数量级,这和神经网络的概率计算的优点并不吻合,实际上它们的值都应该在[0,1]之间的。

大数量级的数据有另外一个问题,就是它的波动有可能很大。目前我们还没有使用激活函数,一旦网络复杂了,开始使用激活函数时,像486.166这种数据,一旦经过激活函数就会发生梯度饱和的现象,输出值总为1,这样对于后面的网络就没什么意义了,因为输入值都是1。

好吧,看起来问题解决得并不完美,我们看看还能有什么更好的解决方案!

代码位置

原代码位置:ch05, Level5

个人代码:NormalizeLabelData**