为什么会有overfitting这样的状况呢,为什么有可能training的loss小,testing的loss大呢,这边就举一个极端的例子。
这是我们的训练集,假设根据这些训练集,有一个很废的function说,如果今天x当做输入的时候,我们就去比对这个x,有没有出现在训练集里面,如果x有出现在训练集里面,就把它对应的ŷ当做输出,如果x没有出现在训练集里面,就输出一个随机的值。
那你可以想像这个function啥事也没有干,虽然它是一个一无是处的function,它在training的data上,它的loss可是0呢。可是在testing data上面,它的loss会变得很大,因为它其实什么都没有预测。
在一般的状况下,也有可能发生类似的事情。
举例来说,假设我们输入的feature叫做x,我们输出的level叫做y,那x跟y都是一维的。
x跟y之间的关系,是这个二次的曲线,这个曲线我们刻意用虚线来表示,因为我们通常没有办法直接观察到这条曲线,我们真正可以观察到的是我们的训练集,训练集就是从这条曲线上面随机sample出来的几个点。
今天的模型它的能力非常的强,你只给它这三个点,它会知道在这三个点上面我们要让loss低,所以你的model它的这个曲线会通过这三个点。但是其它没有训练集做为限制的地方,它就会有freestyle,因为它的弹性很大,所以你的model可以变成各式各样的function,可以产生各式各样奇怪的结果。
这个时候,如果你再丢进你的testing data,用蓝色的这些点,找出一个function以后,你测试在橘色的这些点上,不一定会好。如果你的model它的自由度很大的话,它可以产生非常奇怪的曲线,导致训练集上的结果好,但是测试集上的loss很大。
解决Overfitting
- 第一个方向是,也许这个方向往往是最有效的方向,是增加你的训练集数量。
但是你在作业里面,你是不能够使用这一招的,因为我们并不希望大家浪费时间,来收集资料。
那你可以做什么呢,你可以做data augmentation,这个方法并不算是使用了额外的资料。
Data augmentation就是用一些你对于这个问题的理解,自己创造出新的资料。例如将图片左右翻转,或者是把它其中一块截出来放大等等,左右翻转后你的资料就变成两倍了。
但是也不能够随便乱做,你就很少看到有人把影像上下颠倒当作augmentation,因为这些图片都是合理的图片,你把一张照片左右翻转,并不会影响到里面是什么样的东西,但你把它颠倒那就很奇怪了,这可能不是一个训练集里面,可能不是真实世界会出现的影像。那如果你给机器看这种奇怪的影像的话,它可能就会学到奇怪的东西。
- 另外一个解法就是不要让你的模型有那么大的弹性,给它一些限制。
举例来说,假设我们直接限制说我们的model,我们猜测出x跟y背后的关系,其实就是一条二次曲线,只是我们不明确知道这条二次曲线里面的每一个参数长什么样。
那现在假设我们已经知道模型就是二次曲线,在选择function的时候就会有很大的限制。所以虽然只给了三个点,但是因为我们能选择的function有限,你可能就会正好选到跟真正的distribution比较接近的function,然后在测试集上得到比较好的结果。
有哪些方法可以给model制造限制呢,举例来说。
- 比较少的参数。如果是deep learning的话,就给它比较少的神经元的数目,或者是可以让model共用参数,你可以让一些参数有一样的数值。
我们之前讲的network的架构,叫做fully-connected network,它是一个比较有弹性的架构,而CNN是一个比较有限制的架构。CNN厉害的地方就是,它是针对影像的特性来限制模型的弹性,所以在影像上反而会做得比较好。
- 比较少的features。
- Early stopping。
- Regularization规范化。
- Dropout。
但是我们也不要给太多的限制。
假设我们现在给模型更大的限制说,一定是Linear的Model,一定是写成y=a+bx,那你的model它能够产生的function就一定是一条直线。
今天给三个点,没有任何一条直线,可以同时通过这三个点,但是你只能找到一条直线,它们的距离是比较近的。这个时候你的模型的限制就太大了,你在测试集上就不会得到好的结果。这个不是overfitting,因为你又回到了model bias的问题。
怎么真的衡量一个模型的弹性,复杂的程度有多大?
所谓比较复杂就是,它可以包含的function比较多,它的参数比较多,这个就是一个比较复杂的model。
但随着model越来越复杂,Training的loss可以越来越低,testing的loss会跟着下降,但当复杂到一定程度时,Testing的loss就会突然暴增了(overfitting)。
Cross Validation
可以把Training的资料分成两半,一部分叫作Training Set,一部分是Validation Set。
先在Training Set上训练出来模型,然后在Validation Set上面去衡量它们的分数,根据Validation Set上面的分数,去挑选结果。
N-fold Cross Validation
如果担心Validation Set分的不好,那你可以用N-fold Cross Validation。
N-fold Cross Validation就是你先把你的训练集切成N等份。在这个例子里面我们切成三等份,切完以后,你拿其中一份当作Validation Set,另外两份当Training Set,然后这件事情你要重复三次。
把每一个模型在这三种状况的结果,都平均起来,再看看谁的结果最好。