本节将介绍在训练过程中如何更好地访问并控制模型内部过程的方法。
使用 model.fit()
或 model.fit_generator()
在一个大型数据集上启动数十轮的训练,有点类似于扔一架纸飞机,一开始给它一点推力,之后你便再也无法控制其飞行轨迹或着陆点。如果想要避免不好的结果(并避免浪费纸飞机),更聪明的做法是不用纸飞机,而是用一架无人机,它可以感知其环境,将数据发回给操纵者,并且能够基于当前状态自主航行。
我们下面要介绍的技术,可以让 model.fit()
的调用从纸飞机变为智能的自主无人机,可以自我反省并动态地采取行动。
训练过程中将回调函数作用于模型
训练模型时,很多事情一开始都无法预测。尤其是你不知道需要多少轮才能得到最佳验证 损失。前面所有例子都采用这样一种策略:训练足够多的轮次,这时模型已经开始过拟合,根 据这第一次运行来确定训练所需要的正确轮数,然后使用这个最佳轮数从头开始再启动一次新 的训练。当然,这种方法很浪费。
处理这个问题的更好方法是,当观测到验证损失不再改善时就停止训练。这可以使用 Keras 回调函数来实现。
回调函数(callback)是在调用 fit 时传入模型的一个对象(即实现特定方法 的类实例),它在训练过程中的不同时间点都会被模型调用。它可以访问关于模型状态与性能的 所有可用数据,还可以采取行动:中断训练、保存模型、加载一组不同的权重或改变模型的状态。
回调函数的一些用法示例:
- 模型检查点(model checkpointing):在训练过程中的不同时间点保存模型的当前权重。
- 提前终止(early stopping):如果验证损失不再改善,则中断训练(当然,同时保存在训 练过程中得到的最佳模型)。
- 在训练过程中动态调节某些参数值:比如优化器的学习率。
- 在训练过程中记录训练指标和验证指标,或将模型学到的表示可视化(这些表示也在不断更新):你熟悉的 Keras 进度条就是一个回调函数!
keras.callbacks 模块包含许多内置的回调函数,下面列出了其中一些,但还有很多没有列出来:
- keras.callbacks.ModelCheckpoint
- keras.callbacks.EarlyStopping
- keras.callbacks.LearningRateScheduler
- keras.callbacks.ReduceLROnPlateau
- keras.callbacks.CSVLogger
Keras 回调函数介绍
1. ModelCheckpoint 与 EarlyStopping 回调函数
如果监控的目标指标在设定的轮数内不再改善,可以用 EarlyStopping
回调函数来中断训练。
这个回调函数通常与 ModelCheckpoint
结合使用,后者可以在训练过程中持续 不断地保存模型(你也可以选择只保存目前的最佳模型,即一轮结束后具有最佳性能的模型)。
import keras
callbacks_list = [
keras.callbacks.EarlyStopping( # 如果不再改善就终止训练
monitor='acc', # 监控模型的验证精度
patience=1, # 如果精度在多于一轮的时间(即两轮)内不再改善,中断训练
),
keras.callbacks.ModelCheckpoint( # 在每轮过后保存当前权重
filepath='my_model.h5', # 目标模型文件的保存路径
monitor='val_loss',
save_best_only=True, # 这两个参数的含义是,如果 val_loss 没有改善,那么不需要覆盖模型文件。这就可以始终保存在训练过程中见到的最佳模型
)
]
model.compile(optimizer='rmsprop',
loss='binary_crossentropy',
metrics=['acc']) # 你监控精度,所以它应该是模型指标的一部分
model.fit(x, y,
epochs=10,
batch_size=32,
callbacks=callbacks_list,
validation_data=(x_val, y_val))
2. ReduceLROnPlateau 回调函数
如果验证损失不再改善,你可以使用这个回调函数来降低学习率。在训练过程中如果出现了损失平台(loss plateau),那么增大或减小学习率都是跳出局部最小值的有效策略。
callbacks_list = [
keras.callbacks.ReduceLROnPlateau(
monitor='val_loss', # 监控模型的验证损失
factor=0.1, # 触发时将学习率除以 10
patience=10, # 如果验证损失在 10 轮内都没有改善,那么就触发这个回调函数
)
]
model.fit(x, y,
epochs=10,
batch_size=32,
callbacks=callbacks_list,
validation_data=(x_val, y_val)) # 注意,因为回调函数要监控验证损失,所以你需要在调用 fit 时传入 validation_data(验证数据)
3. 编写你自己的回调函数
如果你需要在训练过程中采取特定行动,而这项行动又没有包含在内置回调函数中,那么可以编写你自己的回调函数。
回调函数的实现方式是创建 keras.callbacks.Callback
类的子类。然后你可以实现下面这些方法(从名称中即可看出这些方法的作用),它们分别在训练过程中的不同时间点被调用:
- on_epoch_begin
- on_epoch_end
- on_batch_begin
- on_batch_end
- on_train_begin
- on_train_end
具体实现时再查找其他资料即可。