PyTorch中的注册机制

  1. import torch
  2. from torch import nn
  3. from mmcv.cnn import constant_init
  4. # hook 函数,其三个参数不能修改(参数名随意),本质上是 PyTorch 内部回调函数
  5. # module 本身对象
  6. # input 该 module forward 前输入
  7. # output 该 module forward 后输出
  def forward_hook_fn(module, input, output):
      """Print a layer's parameters together with its forward input and output.

      The three-positional-argument signature is the one used when this
      function is passed to ``register_forward_hook``.

      Args:
          module: Layer the hook is attached to. It must expose ``weight``
              and ``bias`` attributes, each with a ``.data`` field.
          input: Value(s) the layer received in its forward call.
          output: Value the layer produced in its forward call.
      """
      # Report the parameters first (weight, then bias), each fetched lazily
      # right before it is printed.
      for attr_name in ('weight', 'bias'):
          print(attr_name, getattr(module, attr_name).data)
      print('input', input)
      print('output', output)
  class Model(nn.Module):
      """Minimal network: one ``nn.Linear(3, 1)`` layer with a forward hook.

      The hook ``forward_hook_fn`` is registered on the linear layer, so it
      is invoked whenever that layer runs its forward pass.
      """

      def __init__(self):
          super().__init__()
          linear = nn.Linear(3, 1)
          linear.register_forward_hook(forward_hook_fn)
          # Initialise the layer with the constant 1 via mmcv's helper.
          constant_init(linear, 1)
          self.fc = linear

      def forward(self, x):
          """Return the linear layer's output for ``x``."""
          return self.fc(x)
  if __name__ == '__main__':
      model = Model()
      # One sample with three features; calling the model triggers the
      # forward hook registered on its linear layer.
      y = model(torch.Tensor([[0.0, 1.0, 2.0]]))
  1. # 输出
  2. weight tensor([[1., 1., 1.]])
  3. bias tensor([0.])
  4. input (tensor([[0., 1., 2.]]),)
  5. output tensor([[3.]], grad_fn=<AddmmBackward>)

MMCV中HOOK基类

mmcv/runner/hooks/hook.py定义的基类

  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from mmcv.utils import Registry, is_method_overridden
  3. HOOKS = Registry('hook')
  class Hook:
      """Base class for hooks: no-op callbacks plus scheduling helpers.

      Subclasses override the callbacks they care about. The train/val
      specific callbacks fall back to the generic ``before_epoch`` /
      ``after_epoch`` / ``before_iter`` / ``after_iter`` callbacks unless
      they are overridden themselves.
      """

      # Every stage name, in the order used when reporting triggered stages.
      stages = (
          'before_run',
          'before_train_epoch',
          'before_train_iter',
          'after_train_iter',
          'after_train_epoch',
          'before_val_epoch',
          'before_val_iter',
          'after_val_iter',
          'after_val_epoch',
          'after_run',
      )

      # ---- generic callbacks: do nothing by default ----

      def before_run(self, runner):
          """No-op by default."""

      def after_run(self, runner):
          """No-op by default."""

      def before_epoch(self, runner):
          """No-op by default."""

      def after_epoch(self, runner):
          """No-op by default."""

      def before_iter(self, runner):
          """No-op by default."""

      def after_iter(self, runner):
          """No-op by default."""

      # ---- train/val callbacks: forward to the generic callback ----

      def before_train_epoch(self, runner):
          """Forward to ``before_epoch``."""
          self.before_epoch(runner)

      def before_val_epoch(self, runner):
          """Forward to ``before_epoch``."""
          self.before_epoch(runner)

      def after_train_epoch(self, runner):
          """Forward to ``after_epoch``."""
          self.after_epoch(runner)

      def after_val_epoch(self, runner):
          """Forward to ``after_epoch``."""
          self.after_epoch(runner)

      def before_train_iter(self, runner):
          """Forward to ``before_iter``."""
          self.before_iter(runner)

      def before_val_iter(self, runner):
          """Forward to ``before_iter``."""
          self.before_iter(runner)

      def after_train_iter(self, runner):
          """Forward to ``after_iter``."""
          self.after_iter(runner)

      def after_val_iter(self, runner):
          """Forward to ``after_iter``."""
          self.after_iter(runner)

      # ---- scheduling helpers ----

      def every_n_epochs(self, runner, n):
          """Return True when ``runner.epoch + 1`` is a multiple of ``n``.

          Always False when ``n`` is not positive.
          """
          if not n > 0:
              return False
          return (runner.epoch + 1) % n == 0

      def every_n_inner_iters(self, runner, n):
          """Return True when ``runner.inner_iter + 1`` is a multiple of ``n``.

          Always False when ``n`` is not positive.
          """
          if not n > 0:
              return False
          return (runner.inner_iter + 1) % n == 0

      def every_n_iters(self, runner, n):
          """Return True when ``runner.iter + 1`` is a multiple of ``n``.

          Always False when ``n`` is not positive.
          """
          if not n > 0:
              return False
          return (runner.iter + 1) % n == 0

      def end_of_epoch(self, runner):
          """Return True when ``runner.inner_iter + 1`` equals the loader length."""
          return runner.inner_iter + 1 == len(runner.data_loader)

      def is_last_epoch(self, runner):
          """Return True when ``runner.epoch + 1`` equals ``runner._max_epochs``."""
          return runner.epoch + 1 == runner._max_epochs

      def is_last_iter(self, runner):
          """Return True when ``runner.iter + 1`` equals ``runner._max_iters``."""
          return runner.iter + 1 == runner._max_iters

      def get_triggered_stages(self):
          """List the stages this hook reacts to, ordered as in ``Hook.stages``.

          A stage is included when ``is_method_overridden`` reports that the
          stage method itself is overridden, or that the generic method
          feeding it (e.g. ``before_epoch`` for both ``before_train_epoch``
          and ``before_val_epoch``) is overridden.
          """
          # A generic callback fans out to several concrete stages.
          generic_to_stages = {
              'before_epoch': ['before_train_epoch', 'before_val_epoch'],
              'after_epoch': ['after_train_epoch', 'after_val_epoch'],
              'before_iter': ['before_train_iter', 'before_val_iter'],
              'after_iter': ['after_train_iter', 'after_val_iter'],
          }

          active = {
              stage for stage in Hook.stages
              if is_method_overridden(stage, Hook, self)
          }
          for generic_name, stage_names in generic_to_stages.items():
              if is_method_overridden(generic_name, Hook, self):
                  active.update(stage_names)

          # Re-walk the canonical tuple so the result has a stable order.
          return [stage for stage in Hook.stages if stage in active]

Hook分类和用法

默认HOOK

  • CheckpointHook(保存ckpt)
  • LrUpdaterHook(学习率调度)
  • OptimizerHook(反向传播+参数更新)
  • Fp16OptimizerHook(混合精度训练)
  • TextLoggerHook(日志打印)
  • IterTimerHook(迭代一次时间统计)
  • 分布式相关(DistSamplerSeedHook)确保shuffle生效
  • MomentumUpdaterHook(动量更新,用于3d目标检测)

定制HOOK

  • EMAHook(模型ema)
  • TensorboardLoggerHook(tensorboard简单可视化)
  • EmptyCacheHook(Pytorch cuda缓存清除)
  • SyncBuffersHook(同步buffer)
  • ClosureHook(简单函数快速注册)
  • 各大框架自定义hook(MMDetection中EvalHook和DistEvalHook)

自定义HOOK(UploadHook)

image.png

  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. from mmcv.runner.hooks import HOOKS, Hook
  @HOOKS.register_module()
  class UploadHook(Hook):
      """Demo hook that prints a message at every stage and validates the loss.

      Each callback prints a line naming the stage it was invoked for. After
      a training iteration the hook additionally checks, every ``interval``
      iterations, that ``runner.outputs['loss']`` is finite.

      Args:
          interval (int): Check the loss every ``interval`` training
              iterations. Default: 1.
      """

      def __init__(self, interval=1):
          self.interval = interval

      def before_run(self, runner):
          print('upload before_run')

      def after_run(self, runner):
          print('upload after_run')

      # NOTE(review): the train/val-specific callbacks below are overridden
      # without calling these four generic callbacks, so the generic ones
      # only run if something invokes them directly. Confirm this is intended.
      def before_epoch(self, runner):
          print('upload before_epoch')

      def after_epoch(self, runner):
          print('upload after_epoch')

      def before_iter(self, runner):
          print('upload before_iter')

      def after_iter(self, runner):
          print('upload after_iter')

      def before_train_epoch(self, runner):
          print('upload before_train_epoch')

      def before_val_epoch(self, runner):
          print('upload before_val_epoch')

      def after_train_epoch(self, runner):
          print('upload after_train_epoch')

      def after_val_epoch(self, runner):
          print('upload after_val_epoch')

      def before_train_iter(self, runner):
          print('upload before_train_iter')

      def before_val_iter(self, runner):
          print('upload before_val_iter')

      def after_train_iter(self, runner):
          """Print the stage name and periodically verify the loss is finite.

          This is the single definition of the callback. A second, later
          definition with the same name would silently replace it and drop
          the loss check.

          Raises:
              AssertionError: If the loss is infinite or NaN on an iteration
                  selected by ``interval``.
          """
          print('upload after_train_iter')
          if self.every_n_iters(runner, self.interval):
              # Explicit raise instead of ``assert`` so the check is not
              # stripped when Python runs with ``-O``.
              if not torch.isfinite(runner.outputs['loss']):
                  message = 'loss become infinite or NaN!'
                  runner.logger.info(message)
                  raise AssertionError(message)

      def after_val_iter(self, runner):
          print('upload after_val_iter')

Hook的运行顺序

If custom hooks have same priority with default hooks, custom hooks will be triggered after default hooks.

当优先级相同时,先运行default hook,再运行custom hook

参考

目标检测(MMdetection)-HOOK机制
MMCV 核心组件分析(六): Hook
【官方文档】Customize hooks