title: Hook functions
subtitle: Hook functions
date: 2021-06-30
author: NSX
catalog: true
tags:
- hook functions
## What is a hook
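A hook function, as the name suggests, works like a hook: it is a function mounted between the code that drives a workflow and the target function that workflow eventually runs. The framework developer exposes a mount point to callers, and the caller decides which function to attach there, which makes the framework far more flexible.
Hook functions serve much the same purpose as another term you will often hear, callback functions, and the two can be understood through the same pattern.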
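
To make the "mount point" idea concrete before the class-based example, here is a minimal function-based sketch. The names `load_numbers` and `on_invalid` are hypothetical and chosen only for illustration; the point is that the library code decides when the hook fires, while the caller decides what it does.

```python
from typing import Callable, List, Optional


def load_numbers(
    raw: List[str],
    on_invalid: Optional[Callable[[str], None]] = None,
) -> List[int]:
    """Parse strings into ints; `on_invalid` is the mount point for a caller-supplied hook."""
    numbers = []
    for item in raw:
        try:
            numbers.append(int(item))
        except ValueError:
            # The library decides *when* the hook fires;
            # the caller decides *what* it does.
            if on_invalid is not None:
                on_invalid(item)
    return numbers


# The caller mounts its own behavior: collect the rejected items.
rejected: List[str] = []
parsed = load_numbers(["1", "x", "3"], on_invalid=rejected.append)
print(parsed)    # [1, 3]
print(rejected)  # ['x']
```

If no hook is mounted, `load_numbers` simply skips invalid items, which mirrors the idea that a hook is an optional step with no behavior of its own.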
## Hook implementation example
```python
class Runner:
    """Drives the training flow and fires hooks at fixed mount points."""

    def __init__(self):
        self._hooks = []

    def register_hook(self, hook):
        self._hooks.append(hook)

    def call_hook(self, fn_name):
        # Call the method named `fn_name` on every registered hook,
        # in registration order, passing the runner itself.
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def train(self):
        self.a = 10
        self.b = 20
        self.call_hook('before_train_epoch')
        print('Done Epoch!')
        self.call_hook('after_train_epoch')


class Hook:
    """Base class: each mount point is a no-op unless a subclass overrides it."""

    def before_train_epoch(self, runner):
        pass

    def after_train_epoch(self, runner):
        pass


class AddHook(Hook):
    def before_train_epoch(self, runner):
        print('i am Add')
        print(f'Add {runner.a} and {runner.b} equal {runner.a + runner.b}\n')


class MulHook(Hook):
    def before_train_epoch(self, runner):
        print('i am Mul')
        print(f'Mul {runner.a} and {runner.b} equal {runner.a * runner.b}\n')


class ExpHook(Hook):
    def after_train_epoch(self, runner):
        print('i am Exp')
        print(f'Exp {runner.a} and {runner.b} equal {runner.a ** runner.b}\n')


class Trainer:
    def __init__(self):
        self.runner = Runner()
        self.runner.register_hook(MulHook())
        self.runner.register_hook(ExpHook())
        self.runner.register_hook(AddHook())

    def run(self):
        self.runner.train()


if __name__ == '__main__':
    trainer = Trainer()
    trainer.run()
```
## Hooks in open-source frameworks
### Keras
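Deep learning training loops are a textbook showcase for hook functions.
A training run (excluding data preparation) iterates over the training set multiple times. Each pass is called an epoch, and each epoch is split into multiple batches. In order, the flow breaks down into: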
- Start training
- Before training an epoch
- Before training a batch
- After training a batch
- After training an epoch
- Evaluate on the validation set
- End training
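These steps are interleaved with the actual training on each batch, and each of them can be treated as a hook. We often need to customize behavior at these points, for example saving the model after each epoch, or evaluating the best model on the test set when training ends.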
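
The steps listed above can be sketched as a bare training loop that fires a hook at each point. This is only an illustration under stated assumptions: the `fit` function, the `TraceHook` class, and the method names are hypothetical, and the ordering simply follows the list above rather than any particular framework.

```python
class TraceHook:
    """Hypothetical hook that records each mount point it is called at."""

    def __init__(self):
        self.events = []

    def on_train_begin(self):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch):
        self.events.append(f"epoch_begin:{epoch}")

    def on_batch_begin(self, batch):
        self.events.append(f"batch_begin:{batch}")

    def on_batch_end(self, batch):
        self.events.append(f"batch_end:{batch}")

    def on_epoch_end(self, epoch):
        self.events.append(f"epoch_end:{epoch}")

    def on_validation_end(self, epoch):
        self.events.append(f"validation:{epoch}")

    def on_train_end(self):
        self.events.append("train_end")


def fit(hooks, num_epochs, num_batches):
    """Skeleton loop: no real training, only the hook mount points."""

    def fire(name, *args):
        # Call the named method on every registered hook, in order.
        for hook in hooks:
            getattr(hook, name)(*args)

    fire("on_train_begin")
    for epoch in range(num_epochs):
        fire("on_epoch_begin", epoch)
        for batch in range(num_batches):
            fire("on_batch_begin", batch)
            # Forward pass, loss, and parameter update would go here.
            fire("on_batch_end", batch)
        fire("on_epoch_end", epoch)
        # Validation-set evaluation would go here.
        fire("on_validation_end", epoch)
    fire("on_train_end")


trace = TraceHook()
fit([trace], num_epochs=1, num_batches=2)
print(trace.events)
```

Running it with one epoch and two batches shows the hooks firing in the listed order, with the batch-level pair repeated once per batch.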
Keras implements hooks through its callback classes. Below is the callback base class. To customize behavior, subclass it and override only the hooks you care about.
```python
@keras_export('keras.callbacks.Callback')
class Callback(object):
  """Abstract base class used to build new callbacks.

  Attributes:
      params: Dict. Training parameters
          (eg. verbosity, batch size, number of epochs...).
      model: Instance of `keras.models.Model`.
          Reference of the model being trained.

  The `logs` dictionary that callback methods
  take as argument will contain keys for quantities relevant to
  the current batch or epoch (see method-specific docstrings).
  """

  def __init__(self):
    self.validation_data = None  # pylint: disable=g-missing-from-attributes
    self.model = None
    # Whether this Callback should only run on the chief worker in a
    # Multi-Worker setting.
    # TODO(omalleyt): Make this attr public once solution is stable.
    self._chief_worker_only = None
    self._supports_tf_logs = False

  def set_params(self, params):
    self.params = params

  def set_model(self, model):
    self.model = model

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_batch_begin(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_begin`."""

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_batch_end(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_end`."""

  @doc_controls.for_subclass_implementers
  def on_epoch_begin(self, epoch, logs=None):
    """Called at the start of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Arguments:
        epoch: Integer, index of epoch.
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_epoch_end(self, epoch, logs=None):
    """Called at the end of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Arguments:
        epoch: Integer, index of epoch.
        logs: Dict, metric results for this training epoch, and for the
          validation epoch if validation is performed. Validation result keys
          are prefixed with `val_`.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_train_batch_begin(self, batch, logs=None):
    """Called at the beginning of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.train_step`. Typically,
          the values of the `Model`'s metrics are returned.  Example:
          `{'loss': 0.2, 'accuracy': 0.7}`.
    """
    # For backwards compatibility.
    self.on_batch_begin(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_train_batch_end(self, batch, logs=None):
    """Called at the end of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """
    # For backwards compatibility.
    self.on_batch_end(batch, logs=logs)

  ...
```
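
As a rough illustration of "subclass and override only what you need", the sketch below tracks the epoch with the lowest validation loss. To keep it runnable without TensorFlow, it uses a tiny stand-in `Callback` class with the same method names as the Keras base class. The `BestEpochTracker` class and the loss values are hypothetical, and the callback is driven by hand rather than by a real training loop.

```python
class Callback:
    """Stand-in for the Keras base class: same method names, no-op bodies."""

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass


class BestEpochTracker(Callback):
    """Hypothetical callback: remember the epoch with the lowest `val_loss`."""

    def __init__(self):
        self.best_epoch = None
        self.best_val_loss = float("inf")

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        val_loss = logs.get("val_loss")
        if val_loss is not None and val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.best_epoch = epoch
            # A real callback might save the model weights here.

    def on_train_end(self, logs=None):
        print(f"best epoch: {self.best_epoch}, val_loss: {self.best_val_loss}")


# Drive the callback by hand, the way a training loop would.
tracker = BestEpochTracker()
for epoch, val_loss in enumerate([0.9, 0.4, 0.6]):
    tracker.on_epoch_end(epoch, logs={"val_loss": val_loss})
tracker.on_train_end()
```

With real Keras, one would subclass `keras.callbacks.Callback` instead of the stand-in and pass the instance through the `callbacks` argument of `model.fit`, leaving every hook that is not overridden as a no-op.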
## Summary
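This post introduced the concept and applications of hooks and walked through a Python implementation. I hope it helps. In summary:
- A hook is a predefined step in a workflow that has no implementation of its own.
- Once a function is mounted or registered at that step, the workflow calls it whenever it reaches the step.
- Callback functions and hook functions serve the same purpose.
- The hook design brings flexibility: if a workflow contains a step that you want the caller to implement, expose it as a hook.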