PyTorch 中的 hook 注册机制
import torch
from torch import nn
from mmcv.cnn import constant_init


def forward_hook_fn(module, input, output):
    """Forward hook that prints a layer's parameters, input and output.

    The hook is a callback invoked by PyTorch itself, so it must accept
    exactly these three arguments (their names are free to choose).

    Args:
        module: The module the hook is registered on.
        input: The input passed to the module's ``forward``.
        output: The value produced by the module's ``forward``.
    """
    print('weight', module.weight.data)
    print('bias', module.bias.data)
    print('input', input)
    print('output', output)


class Model(nn.Module):
    """Single linear layer (3 -> 1) with a forward hook attached."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)
        # Attach the hook to the layer; it is invoked on each forward pass
        # of ``self.fc``.
        self.fc.register_forward_hook(forward_hook_fn)
        # Initialise the layer with the constant 1 so that the printed
        # values are deterministic.
        constant_init(self.fc, 1)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        o = self.fc(x)
        return o


if __name__ == '__main__':
    model = Model()
    x = torch.Tensor([[0.0, 1.0, 2.0]])
    # Running the model triggers ``forward_hook_fn`` as a side effect.
    y = model(x)
# 输出
# weight: tensor([[1., 1., 1.]])
# bias: tensor([0.])
# input: (tensor([[0., 1., 2.]]),)
# output: tensor([[3.]], grad_fn=<AddmmBackward>)
MMCV中HOOK基类
mmcv/runner/hooks/hook.py定义的基类
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, is_method_overridden

# Registry named 'hook'.
HOOKS = Registry('hook')


class Hook:
    """Base class for runner hooks.

    Every stage method is a no-op or delegates to a more generic method, so
    a subclass only overrides the stages it cares about. The train/val
    specific methods fall back to the generic ``before_epoch`` /
    ``after_epoch`` / ``before_iter`` / ``after_iter`` methods.
    """

    # Stage names, in the order used by ``get_triggered_stages``.
    stages = ('before_run', 'before_train_epoch', 'before_train_iter',
              'after_train_iter', 'after_train_epoch', 'before_val_epoch',
              'before_val_iter', 'after_val_iter', 'after_val_epoch',
              'after_run')

    def before_run(self, runner):
        """Stage callback; does nothing by default."""
        pass

    def after_run(self, runner):
        """Stage callback; does nothing by default."""
        pass

    def before_epoch(self, runner):
        """Generic epoch-start callback shared by train and val."""
        pass

    def after_epoch(self, runner):
        """Generic epoch-end callback shared by train and val."""
        pass

    def before_iter(self, runner):
        """Generic iteration-start callback shared by train and val."""
        pass

    def after_iter(self, runner):
        """Generic iteration-end callback shared by train and val."""
        pass

    def before_train_epoch(self, runner):
        """Delegate to ``before_epoch``."""
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        """Delegate to ``before_epoch``."""
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        """Delegate to ``after_epoch``."""
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        """Delegate to ``after_epoch``."""
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        """Delegate to ``before_iter``."""
        self.before_iter(runner)

    def before_val_iter(self, runner):
        """Delegate to ``before_iter``."""
        self.before_iter(runner)

    def after_train_iter(self, runner):
        """Delegate to ``after_iter``."""
        self.after_iter(runner)

    def after_val_iter(self, runner):
        """Delegate to ``after_iter``."""
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        """Return True when ``runner.epoch + 1`` is a multiple of ``n``.

        Always False when ``n <= 0``.
        """
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        """Return True when ``runner.inner_iter + 1`` is a multiple of ``n``.

        Always False when ``n <= 0``.
        """
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        """Return True when ``runner.iter + 1`` is a multiple of ``n``.

        Always False when ``n <= 0``.
        """
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        """Return True when ``runner.inner_iter + 1`` equals the length of
        ``runner.data_loader``."""
        return runner.inner_iter + 1 == len(runner.data_loader)

    def is_last_epoch(self, runner):
        """Return True when ``runner.epoch + 1`` equals
        ``runner._max_epochs``."""
        return runner.epoch + 1 == runner._max_epochs

    def is_last_iter(self, runner):
        """Return True when ``runner.iter + 1`` equals
        ``runner._max_iters``."""
        return runner.iter + 1 == runner._max_iters

    def get_triggered_stages(self):
        """Return the stages this hook reacts to, in ``Hook.stages`` order.

        A stage is included when ``is_method_overridden`` reports its method
        as overridden on this instance, or when one of the generic methods
        that maps to it is reported as overridden.

        Returns:
            list[str]: Stage names, ordered as in ``Hook.stages``.
        """
        trigger_stages = set()
        for stage in Hook.stages:
            if is_method_overridden(stage, Hook, self):
                trigger_stages.add(stage)

        # some methods will be triggered in multi stages
        # use this dict to map method to stages.
        method_stages_map = {
            'before_epoch': ['before_train_epoch', 'before_val_epoch'],
            'after_epoch': ['after_train_epoch', 'after_val_epoch'],
            'before_iter': ['before_train_iter', 'before_val_iter'],
            'after_iter': ['after_train_iter', 'after_val_iter'],
        }

        for method, map_stages in method_stages_map.items():
            if is_method_overridden(method, Hook, self):
                trigger_stages.update(map_stages)

        # Filter through Hook.stages so the result has a stable order.
        return [stage for stage in Hook.stages if stage in trigger_stages]
Hook分类和用法
默认HOOK
- CheckpointHook(保存ckpt)
- LrUpdaterHook(学习率调度)
- OptimizerHook(反向传播+参数更新)
- Fp16OptimizerHook(混合精度训练)
- TextLoggerHook(日志打印)
- IterTimerHook(迭代一次时间统计)
- 分布式相关(DistSamplerSeedHook)确保shuffle生效
- MomentumUpdaterHook(动量更新,用于3d目标检测)
定制HOOK
- EMAHook(模型ema)
- TensorboardLoggerHook(tensorboard简单可视化)
- EmptyCacheHook(Pytorch cuda缓存清除)
- SyncBuffersHook(同步模型 buffer)
- ClosureHook(简单函数快速注册)
- 各大框架自定义hook(MMDetection中EvalHook和DistEvalHook)
自定义HOOK(UploadHook)

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner.hooks import HOOKS, Hook


@HOOKS.register_module()
class UploadHook(Hook):
    """Demo hook that prints a message at every stage and checks the loss.

    Each stage method prints its own name so the order in which the stages
    fire can be observed. In addition, ``after_train_iter`` regularly checks
    whether the training loss is finite.

    Args:
        interval (int): Checking interval (every k iterations).
            Default: 1.
    """

    def __init__(self, interval=1):
        self.interval = interval

    def before_run(self, runner):
        print('upload before_run')

    def after_run(self, runner):
        print('upload after_run')

    # NOTE(review): every train/val specific method below is overridden
    # without delegating to the four generic methods, so these generic
    # methods only run if something calls them directly -- confirm against
    # the runner in use.
    def before_epoch(self, runner):
        print('upload before_epoch')

    def after_epoch(self, runner):
        print('upload after_epoch')

    def before_iter(self, runner):
        print('upload before_iter')

    def after_iter(self, runner):
        print('upload after_iter')

    def before_train_epoch(self, runner):
        print('upload before_train_epoch')

    def before_val_epoch(self, runner):
        print('upload before_val_epoch')

    def after_train_epoch(self, runner):
        print('upload after_train_epoch')

    def after_val_epoch(self, runner):
        print('upload after_val_epoch')

    def before_train_iter(self, runner):
        print('upload before_train_iter')

    def before_val_iter(self, runner):
        print('upload before_val_iter')

    def after_train_iter(self, runner):
        """Print the stage name and, every ``interval`` iterations, verify
        that ``runner.outputs['loss']`` is finite.

        This method used to be defined twice in the class body; the later,
        print-only definition replaced the earlier one, so the loss check
        never ran. Both are merged here and the printed text is the one
        that was effective at runtime.

        Raises:
            AssertionError: If the loss is infinite or NaN.
        """
        print('upload after_train_iter')
        if self.every_n_iters(runner, self.interval):
            if not torch.isfinite(runner.outputs['loss']):
                msg = 'loss become infinite or NaN!'
                runner.logger.info(msg)
                # Raised explicitly rather than via ``assert`` so the check
                # survives ``python -O`` and the error carries a message.
                raise AssertionError(msg)

    def after_val_iter(self, runner):
        print('upload after_val_iter')
Hook 的运行顺序
If custom hooks have same priority with default hooks, custom hooks will be triggered after default hooks.
即:优先级相同时,先运行 default hook,再运行 custom hook
参考
目标检测(MMdetection)-HOOK机制
MMCV 核心组件分析(六): Hook
【官方文档】Customize hooks
