title: Hook functions
subtitle: Hook functions
date: 2021-06-30
author: NSX
catalog: true
tags:

  • hook functions

What is a hook


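A hook function, as the name suggests, works like a hook: it is a function hung between the calling code and the target function. The framework developer gives the caller a mount point, and the caller decides which function to hang there. The framework controls when the hook runs, the caller controls what runs, and this greatly increases flexibility.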

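Hook functions serve much the same purpose as callback functions, a term you have probably heard more often, and both can be understood through the same pattern.

Hook implementation example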
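A minimal sketch of the idea, using hypothetical names: `process_items` plays the role of the framework, and its `on_item` parameter is the mount point. When no function is supplied, the step is simply skipped. Passing a plain function in like this is the callback style mentioned above.

```python
from typing import Callable, Iterable, List, Optional


def process_items(items: Iterable[int],
                  on_item: Optional[Callable[[int], None]] = None) -> List[int]:
    """Hypothetical framework function that doubles each item.

    `on_item` is the mount point: if the caller supplies a function,
    it is called with every result; otherwise the step is skipped.
    """
    results = []
    for item in items:
        result = item * 2
        results.append(result)
        if on_item is not None:  # hook point: the caller decides what happens here
            on_item(result)
    return results


seen = []
# The caller "hangs" its own function (here, list.append) on the hook.
output = process_items([1, 2, 3], on_item=seen.append)
print(output)  # [2, 4, 6]
print(seen)    # [2, 4, 6]
```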

```python
class Runner:
    """Owns the workflow and exposes named hook points."""

    def __init__(self):
        self._hooks = []

    def register_hook(self, hook):
        # Mount a hook object; it will be consulted at every hook point.
        self._hooks.append(hook)

    def call_hook(self, fn_name):
        # Call the method named fn_name on every registered hook, in
        # registration order, passing the runner so hooks can read its state.
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def train(self):
        self.a = 10
        self.b = 20
        self.call_hook('before_train_epoch')  # hook point 1
        print('Done Epoch!')
        self.call_hook('after_train_epoch')   # hook point 2


class Hook:
    """Base class: every hook point is a no-op until a subclass overrides it."""

    def before_train_epoch(self, runner):
        pass

    def after_train_epoch(self, runner):
        pass


class AddHook(Hook):
    def before_train_epoch(self, runner):
        print('i am Add')
        print(f'Add {runner.a} and {runner.b} equal {runner.a + runner.b}\n')


class MulHook(Hook):
    def before_train_epoch(self, runner):
        print('i am Mul')
        print(f'Mul {runner.a} and {runner.b} equal {runner.a * runner.b}\n')


class ExpHook(Hook):
    def after_train_epoch(self, runner):
        print('i am Exp')
        print(f'Exp {runner.a} and {runner.b} equal {runner.a ** runner.b}\n')


class Trainer:
    def __init__(self):
        self.runner = Runner()
        # Registration order decides the call order at each hook point.
        self.runner.register_hook(MulHook())
        self.runner.register_hook(ExpHook())
        self.runner.register_hook(AddHook())

    def run(self):
        self.runner.train()


if __name__ == '__main__':
    trainer = Trainer()
    trainer.run()
```
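The same pattern does not strictly need a `Hook` base class. As an alternative sketch (the `register` decorator and the `state` dict are illustrative assumptions, not part of the original example), hooks can be plain functions stored in a dict keyed by event name. An event with no registered functions is simply a no-op, which plays the role of the empty methods in `Hook` above.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, List

# Hypothetical registry: event name -> list of hook functions.
_hooks: DefaultDict[str, List[Callable[[dict], None]]] = defaultdict(list)


def register(event: str):
    """Decorator that registers a function as a hook for `event`."""
    def decorator(fn: Callable[[dict], None]) -> Callable[[dict], None]:
        _hooks[event].append(fn)
        return fn
    return decorator


def call_hook(event: str, state: dict) -> None:
    """Call every hook registered for `event`, in registration order."""
    for fn in _hooks[event]:
        fn(state)


def train(state: dict) -> None:
    call_hook('before_train_epoch', state)
    state['log'].append('Done Epoch!')
    call_hook('after_train_epoch', state)


@register('before_train_epoch')
def add_hook(state: dict) -> None:
    a, b = state['a'], state['b']
    state['log'].append(f'Add {a} and {b} equal {a + b}')


@register('after_train_epoch')
def exp_hook(state: dict) -> None:
    a, b = state['a'], state['b']
    state['log'].append(f'Exp {a} and {b} equal {a ** b}')


state = {'a': 10, 'b': 20, 'log': []}
train(state)
print('\n'.join(state['log']))
```

Registration happens as a side effect of the decorator, so adding a new behavior means writing one decorated function rather than a new subclass.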

Hooks in open-source frameworks

Keras

The deep learning training loop is where hook functions show their value most clearly.

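A training run (excluding data preparation) iterates over the training set several times; each pass is called an epoch, and each epoch is further split into multiple batches. In order, the process breaks down into: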

  • Start of training
  • Before each epoch
  • Before each batch
  • After each batch
  • After each epoch
  • Evaluation on the validation set
  • End of training

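These stages are interleaved with the per-batch training work, and each of them can be treated as a hook where we may want custom behavior: for example, saving the model after each epoch, or evaluating the best model on the test set once training ends.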
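The stages listed above map directly onto a training loop. Below is a minimal, framework-free sketch; the class and method names (`TrainHook`, `on_epoch_end`, and so on) are illustrative assumptions, the losses are fake, validation is folded into the end-of-epoch stage, and `KeepBestHook` only remembers the best epoch instead of actually writing a checkpoint.

```python
from typing import List


class TrainHook:
    """Base class: every stage is a no-op until a subclass overrides it."""

    def on_train_begin(self, trainer): pass
    def on_epoch_begin(self, trainer, epoch): pass
    def on_batch_begin(self, trainer, batch): pass
    def on_batch_end(self, trainer, batch, loss): pass
    def on_epoch_end(self, trainer, epoch, val_loss): pass
    def on_train_end(self, trainer): pass


class BatchCounter(TrainHook):
    """Counts how many batches were trained."""

    def __init__(self):
        self.count = 0

    def on_batch_end(self, trainer, batch, loss):
        self.count += 1


class KeepBestHook(TrainHook):
    """Hypothetical checkpoint hook: remembers the epoch with the lowest
    validation loss. A real one would save the model weights here."""

    def __init__(self):
        self.best_epoch = None
        self.best_val_loss = float('inf')

    def on_epoch_end(self, trainer, epoch, val_loss):
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.best_epoch = epoch

    def on_train_end(self, trainer):
        trainer.log.append(f'best epoch {self.best_epoch}, val_loss {self.best_val_loss}')


class Trainer:
    def __init__(self, hooks: List[TrainHook], val_losses: List[float],
                 batches_per_epoch: int = 3):
        self.hooks = hooks
        self.val_losses = val_losses  # fake validation losses, one per epoch
        self.batches_per_epoch = batches_per_epoch
        self.log: List[str] = []

    def _call(self, name: str, *args) -> None:
        # Fire one stage on every hook, in registration order.
        for hook in self.hooks:
            getattr(hook, name)(self, *args)

    def fit(self) -> None:
        self._call('on_train_begin')
        for epoch, val_loss in enumerate(self.val_losses):
            self._call('on_epoch_begin', epoch)
            for batch in range(self.batches_per_epoch):
                self._call('on_batch_begin', batch)
                loss = 1.0 / (epoch + batch + 1)  # stand-in for a real training step
                self._call('on_batch_end', batch, loss)
            # Validation is faked: val_loss comes from the preset list.
            self._call('on_epoch_end', epoch, val_loss)
        self._call('on_train_end')


best, counter = KeepBestHook(), BatchCounter()
trainer = Trainer(hooks=[best, counter], val_losses=[0.9, 0.5, 0.7])
trainer.fit()
print(trainer.log)    # ['best epoch 1, val_loss 0.5']
print(counter.count)  # 9
```

The loop itself never changes; saving, logging, or early stopping are added by registering another hook.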

Keras implements hooks through its callbacks. Below is the callback base class; to customize training, subclass it and implement only the hooks you care about.

```python
# Excerpt from the TensorFlow 2.x source of keras.callbacks.Callback.
# keras_export, doc_controls and generic_utils are TensorFlow-internal
# helpers imported at the top of that file; the rest of the class is elided.
@keras_export('keras.callbacks.Callback')
class Callback(object):
    """Abstract base class used to build new callbacks.

    Attributes:
        params: Dict. Training parameters
            (eg. verbosity, batch size, number of epochs...).
        model: Instance of `keras.models.Model`.
            Reference of the model being trained.

    The `logs` dictionary that callback methods
    take as argument will contain keys for quantities relevant to
    the current batch or epoch (see method-specific docstrings).
    """

    def __init__(self):
        self.validation_data = None  # pylint: disable=g-missing-from-attributes
        self.model = None
        # Whether this Callback should only run on the chief worker in a
        # Multi-Worker setting.
        # TODO(omalleyt): Make this attr public once solution is stable.
        self._chief_worker_only = None
        self._supports_tf_logs = False

    def set_params(self, params):
        self.params = params

    def set_model(self, model):
        self.model = model

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_batch_begin(self, batch, logs=None):
        """A backwards compatibility alias for `on_train_batch_begin`."""

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_batch_end(self, batch, logs=None):
        """A backwards compatibility alias for `on_train_batch_end`."""

    @doc_controls.for_subclass_implementers
    def on_epoch_begin(self, epoch, logs=None):
        """Called at the start of an epoch.

        Subclasses should override for any actions to run. This function should only
        be called during TRAIN mode.

        Arguments:
            epoch: Integer, index of epoch.
            logs: Dict. Currently no data is passed to this argument for this method
              but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    def on_epoch_end(self, epoch, logs=None):
        """Called at the end of an epoch.

        Subclasses should override for any actions to run. This function should only
        be called during TRAIN mode.

        Arguments:
            epoch: Integer, index of epoch.
            logs: Dict, metric results for this training epoch, and for the
              validation epoch if validation is performed. Validation result keys
              are prefixed with `val_`.
        """

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_train_batch_begin(self, batch, logs=None):
        """Called at the beginning of a training batch in `fit` methods.

        Subclasses should override for any actions to run.

        Arguments:
            batch: Integer, index of batch within the current epoch.
            logs: Dict, contains the return value of `model.train_step`. Typically,
              the values of the `Model`'s metrics are returned. Example:
              `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        # For backwards compatibility.
        self.on_batch_begin(batch, logs=logs)

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_train_batch_end(self, batch, logs=None):
        """Called at the end of a training batch in `fit` methods.

        Subclasses should override for any actions to run.

        Arguments:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """
        # For backwards compatibility.
        self.on_batch_end(batch, logs=logs)

    ...
```
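To see why subclassing is enough, here is a simplified, framework-free stand-in for how a `fit`-style loop might dispatch to such callbacks. `MiniCallback`, `mini_fit`, and `LossHistory` are hypothetical names that only mimic the method names in the excerpt above (including the `on_train_batch_*` to `on_batch_*` alias); this is not Keras's actual dispatch code, and the losses are fake.

```python
from typing import Dict, List, Optional


class MiniCallback:
    """Stand-in for keras.callbacks.Callback: every hook defaults to a no-op."""

    def set_model(self, model):
        self.model = model

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None): pass
    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None): pass
    def on_batch_begin(self, batch: int, logs: Optional[Dict] = None): pass
    def on_batch_end(self, batch: int, logs: Optional[Dict] = None): pass

    def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None):
        self.on_batch_begin(batch, logs=logs)  # backwards-compatible alias

    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None):
        self.on_batch_end(batch, logs=logs)


def mini_fit(model, callbacks: List[MiniCallback], epochs: int, steps: int) -> None:
    """Hypothetical training loop that fires the hooks in a Keras-like order."""
    for cb in callbacks:
        cb.set_model(model)
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        last_loss = None
        for step in range(steps):
            for cb in callbacks:
                cb.on_train_batch_begin(step)
            last_loss = 1.0 / (epoch * steps + step + 1)  # fake metric
            for cb in callbacks:
                cb.on_train_batch_end(step, logs={'loss': last_loss})
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs={'loss': last_loss})


class LossHistory(MiniCallback):
    """Overrides only the hooks it needs; everything else stays a no-op."""

    def __init__(self):
        self.batch_losses = []
        self.epoch_losses = []

    def on_batch_end(self, batch, logs=None):
        self.batch_losses.append(logs['loss'])

    def on_epoch_end(self, epoch, logs=None):
        self.epoch_losses.append(logs['loss'])


history = LossHistory()
mini_fit(model='dummy-model', callbacks=[history], epochs=2, steps=2)
print(history.epoch_losses)  # [0.5, 0.25]
```

`LossHistory` overrides the legacy `on_batch_end`, yet still receives every batch because the default `on_train_batch_end` forwards to it, which mirrors the backwards-compatibility alias in the excerpt.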

Summary

This post introduced the concept of hooks and how they are used, and gave a Python implementation. I hope it helps. To summarize:

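  • A hook is a predefined step in a workflow that has no implementation of its own
  • Once a function is mounted (registered) on it, the workflow runs that function when execution reaches the step
  • Callback functions and hook functions serve the same purpose
  • The hook design buys flexibility: if there is a step in your workflow that you want the caller to implement, expose it as a hook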
