---
title: Hook Functions
subtitle: Hook Functions
date: 2021-06-30
author: NSX
catalog: true
tags:
- hook functions
---
## What Is a Hook

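A hook function, as the name suggests, works like a hook: it is a function mounted between the calling code and the target function. The framework developer gives the caller a *point* (a mount point), and the caller decides which function to hang on it. Custom behaviour can then be plugged in without editing the framework itself, which makes the design far more flexible.
Hook functions serve much the same purpose as another term you will often hear, callback functions, and the two can be understood through the same pattern.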
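To make the mount-point idea concrete, here is a minimal sketch that uses plain functions as callbacks and needs only the standard library. The names `HookPoints`, `register`, `fire`, and `process` are hypothetical: `process` plays the framework and fixes *where* hooks fire, while the caller decides *what* runs there.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, List


class HookPoints:
    """Hypothetical registry mapping a hook-point name to the callbacks mounted on it."""

    def __init__(self) -> None:
        self._callbacks: DefaultDict[str, List[Callable[..., None]]] = defaultdict(list)

    def register(self, point: str, fn: Callable[..., None]) -> None:
        # The caller chooses what to mount; the framework only knows the point's name.
        self._callbacks[point].append(fn)

    def fire(self, point: str, *args) -> None:
        # Call every callback mounted on this point, in registration order.
        # If nothing is mounted, this is a no-op.
        for fn in self._callbacks[point]:
            fn(*args)


def process(items: List[int], hooks: HookPoints) -> int:
    """'Framework' code: a fixed workflow with two mount points."""
    hooks.fire("before_process", items)
    total = sum(items)
    hooks.fire("after_process", total)
    return total


if __name__ == "__main__":
    log: List[str] = []
    hooks = HookPoints()
    hooks.register("before_process", lambda items: log.append(f"got {len(items)} items"))
    hooks.register("after_process", lambda total: log.append(f"total is {total}"))
    print(process([1, 2, 3], hooks))  # 6
    print(log)  # ['got 3 items', 'total is 6']
```

The workflow in `process` never changes; only the functions registered on its two points do.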
## Hook Implementation Example
```python
class Runner:
    """Owns the workflow and exposes mount points (hooks) inside it."""

    def __init__(self):
        self._hooks = []

    def register_hook(self, hook):
        # Hooks are called in the order they are registered.
        self._hooks.append(hook)

    def call_hook(self, fn_name):
        # Call the method named fn_name on every registered hook,
        # passing the runner so each hook can read its state.
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def train(self):
        self.a = 10
        self.b = 20
        self.call_hook('before_train_epoch')
        print('Done Epoch!')
        self.call_hook('after_train_epoch')


class Hook:
    """Base class: every mount point does nothing unless a subclass overrides it."""

    def before_train_epoch(self, runner):
        pass

    def after_train_epoch(self, runner):
        pass


class AddHook(Hook):
    def before_train_epoch(self, runner):
        print('i am Add')
        print(f'Add {runner.a} and {runner.b} equal {runner.a + runner.b}\n')


class MulHook(Hook):
    def before_train_epoch(self, runner):
        print('i am Mul')
        print(f'Mul {runner.a} and {runner.b} equal {runner.a * runner.b}\n')


class ExpHook(Hook):
    def after_train_epoch(self, runner):
        print('i am Exp')
        print(f'Exp {runner.a} and {runner.b} equal {runner.a ** runner.b}\n')


class Trainer:
    def __init__(self):
        self.runner = Runner()
        # The caller decides which hooks to mount and in what order.
        self.runner.register_hook(MulHook())
        self.runner.register_hook(ExpHook())
        self.runner.register_hook(AddHook())

    def run(self):
        self.runner.train()


if __name__ == '__main__':
    trainer = Trainer()
    trainer.run()
```
## Hooks in Open-Source Frameworks
### Keras
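Deep learning training loops make especially heavy use of hook functions.
A training run (excluding data preparation) passes over the training set several times; each pass is called an epoch, and each epoch is split into many batches. The process breaks down into these steps, in order:
- Start training
- Before each epoch
- Before each batch
- After each batch
- After each epoch
- Evaluate on the validation set
- End training
These steps are interleaved with the work of training on each batch, and each one can be treated as a hook. We often need custom behaviour at these points: for example, saving the model after each epoch, or running the best model on the test set once training ends.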
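The step list above maps directly onto a loop with one hook call per step. The sketch below is a framework-free toy, not how Keras or any particular library is written: `TrainHook`, `ToyTrainer`, and `BestCheckpointHook` are hypothetical names, the "loss" is fake, and "saving a checkpoint" is simulated by remembering which epoch was best.

```python
from typing import List, Optional


class TrainHook:
    """Hypothetical base class: each step from the list is a no-op by default."""

    def before_train(self, trainer): pass
    def before_epoch(self, trainer): pass
    def before_batch(self, trainer): pass
    def after_batch(self, trainer): pass
    def after_epoch(self, trainer): pass
    def after_validation(self, trainer): pass
    def after_train(self, trainer): pass


class ToyTrainer:
    """Toy training loop that fires one hook per step; the loss is fake."""

    def __init__(self, hooks: List[TrainHook], epochs: int = 3, batches: int = 2):
        self.hooks = hooks
        self.epochs = epochs
        self.batches = batches
        self.epoch = 0
        self.loss = 1.0
        self.val_loss: Optional[float] = None

    def _call(self, name: str) -> None:
        # Dispatch one step to every hook, in registration order.
        for hook in self.hooks:
            getattr(hook, name)(self)

    def fit(self) -> None:
        self._call("before_train")            # start training
        for epoch in range(self.epochs):
            self.epoch = epoch
            self._call("before_epoch")        # before an epoch
            for _ in range(self.batches):
                self._call("before_batch")    # before a batch
                self.loss *= 0.5              # stand-in for one optimisation step
                self._call("after_batch")     # after a batch
            self._call("after_epoch")         # after an epoch
            self.val_loss = self.loss + 0.1   # stand-in for validation
            self._call("after_validation")    # after evaluating the validation set
        self._call("after_train")             # end training


class BestCheckpointHook(TrainHook):
    """Remembers the best epoch by validation loss instead of writing files."""

    def __init__(self):
        self.best_val: Optional[float] = None
        self.best_epoch: Optional[int] = None
        self.events: List[str] = []

    def after_validation(self, trainer):
        if self.best_val is None or trainer.val_loss < self.best_val:
            self.best_val = trainer.val_loss
            self.best_epoch = trainer.epoch
            self.events.append(f"saved checkpoint at epoch {trainer.epoch}")

    def after_train(self, trainer):
        self.events.append(f"test with best model from epoch {self.best_epoch}")


if __name__ == "__main__":
    ckpt = BestCheckpointHook()
    ToyTrainer([ckpt]).fit()
    print("\n".join(ckpt.events))
```

The loop itself never changes; adding behaviour such as logging or early stopping means writing another `TrainHook` subclass and passing it in.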
Keras implements hooks through callbacks. Below is the `Callback` base class (excerpted); to customise training, subclass it and override only the hooks you care about.
```python
@keras_export('keras.callbacks.Callback')
class Callback(object):
  """Abstract base class used to build new callbacks.

  Attributes:
      params: Dict. Training parameters
          (eg. verbosity, batch size, number of epochs...).
      model: Instance of `keras.models.Model`.
          Reference of the model being trained.

  The `logs` dictionary that callback methods
  take as argument will contain keys for quantities relevant to
  the current batch or epoch (see method-specific docstrings).
  """

  def __init__(self):
    self.validation_data = None  # pylint: disable=g-missing-from-attributes
    self.model = None
    # Whether this Callback should only run on the chief worker in a
    # Multi-Worker setting.
    # TODO(omalleyt): Make this attr public once solution is stable.
    self._chief_worker_only = None
    self._supports_tf_logs = False

  def set_params(self, params):
    self.params = params

  def set_model(self, model):
    self.model = model

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_batch_begin(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_begin`."""

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_batch_end(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_end`."""

  @doc_controls.for_subclass_implementers
  def on_epoch_begin(self, epoch, logs=None):
    """Called at the start of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Arguments:
        epoch: Integer, index of epoch.
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_epoch_end(self, epoch, logs=None):
    """Called at the end of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Arguments:
        epoch: Integer, index of epoch.
        logs: Dict, metric results for this training epoch, and for the
          validation epoch if validation is performed. Validation result keys
          are prefixed with `val_`.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_train_batch_begin(self, batch, logs=None):
    """Called at the beginning of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.train_step`. Typically,
          the values of the `Model`'s metrics are returned. Example:
          `{'loss': 0.2, 'accuracy': 0.7}`.
    """
    # For backwards compatibility.
    self.on_batch_begin(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_train_batch_end(self, batch, logs=None):
    """Called at the end of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """
    # For backwards compatibility.
    self.on_batch_end(batch, logs=logs)

  ...
```
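Two details of the excerpt above are worth spelling out. Every hook method is an empty default, so a subclass overrides only what it needs. And the default `on_train_batch_begin`/`on_train_batch_end` forward to the older `on_batch_begin`/`on_batch_end`, so a callback written against the old names keeps working. The sketch below imitates that pattern in plain Python so it runs without TensorFlow; `MiniCallback`, `mini_fit`, `LossHistory`, and `LegacyBatchCounter` are hypothetical stand-ins, not Keras APIs, and the real `Model.fit` does considerably more.

```python
from typing import Dict, List, Optional


class MiniCallback:
    """Imitates the shape of the Keras Callback base class (not the real API)."""

    def __init__(self):
        self.model = None

    def set_model(self, model):
        self.model = model

    # Legacy batch names: empty defaults.
    def on_batch_begin(self, batch: int, logs: Optional[Dict] = None): pass
    def on_batch_end(self, batch: int, logs: Optional[Dict] = None): pass

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None): pass
    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None): pass

    # Newer names forward to the legacy aliases by default,
    # mirroring the "For backwards compatibility" lines in the excerpt.
    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None):
        self.on_batch_begin(batch, logs=logs)

    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None):
        self.on_batch_end(batch, logs=logs)


def mini_fit(callbacks: List[MiniCallback], epochs: int, steps: int) -> None:
    """Hypothetical fit loop that fires callback events in a Keras-like order."""
    model = object()  # placeholder for the model being trained
    for cb in callbacks:
        cb.set_model(model)
    loss = 1.0
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        for batch in range(steps):
            for cb in callbacks:
                cb.on_train_batch_begin(batch)
            loss *= 0.9  # fake training step
            for cb in callbacks:
                cb.on_train_batch_end(batch, logs={"loss": loss})
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs={"loss": loss})


class LossHistory(MiniCallback):
    """Overrides only the epoch-end hook it cares about."""

    def __init__(self):
        super().__init__()
        self.epoch_losses: List[float] = []

    def on_epoch_end(self, epoch, logs=None):
        self.epoch_losses.append(logs["loss"])


class LegacyBatchCounter(MiniCallback):
    """Overrides only the old-style name; still reached through the alias."""

    def __init__(self):
        super().__init__()
        self.batches_seen = 0

    def on_batch_end(self, batch, logs=None):
        self.batches_seen += 1


if __name__ == "__main__":
    history, counter = LossHistory(), LegacyBatchCounter()
    mini_fit([history, counter], epochs=2, steps=3)
    print(history.epoch_losses, counter.batches_seen)
```

`LegacyBatchCounter` is never called by its own method name: the loop calls `on_train_batch_end`, and the inherited default forwards to the overridden `on_batch_end`.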
## Summary
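This post introduced the concept of hooks and where they appear in practice, and walked through a Python implementation. I hope it helps. In short:
- A hook is a step predefined in a workflow that the workflow itself leaves unimplemented.
- Once a function is mounted (registered) on that hook, the workflow calls it whenever execution reaches that step.
- Callback functions and hook functions serve essentially the same purpose.
- Hooks make a design flexible: if there is a step in your workflow that you want the caller to implement, expose it as a hook.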
