PyTorch中的注册机制
import torch
from torch import nn
from mmcv.cnn import constant_init
# hook 函数,其三个参数不能修改(参数名随意),本质上是 PyTorch 内部回调函数
# module 本身对象
# input 该 module forward 前输入
# output 该 module forward 后输出
def forward_hook_fn(module, input, output):
    """Forward hook that prints a layer's parameters, inputs and output.

    Intended to be attached with ``module.register_forward_hook``; PyTorch
    then invokes it with these three positional arguments.

    Args:
        module: The module the hook is attached to. It must expose
            ``weight`` and ``bias`` attributes.
        input: The inputs given to the module's forward call.
        output: The value produced by the module's forward call.
    """
    print('weight', module.weight.data)
    # Layers created with ``bias=False`` expose ``bias`` as None, so guard
    # the ``.data`` access instead of raising AttributeError.
    bias = module.bias
    print('bias', None if bias is None else bias.data)
    print('input', input)
    print('output', output)
class Model(nn.Module):
    """Minimal module with one linear layer that has a forward hook attached."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)
        # Register the hook on the layer so it is called on each forward
        # pass of ``self.fc``.
        self.fc.register_forward_hook(forward_hook_fn)
        # Initialise the layer with the constant 1 through mmcv's helper.
        constant_init(self.fc, 1)

    def forward(self, x):
        """Run ``x`` through the linear layer and return the result."""
        return self.fc(x)
if __name__ == '__main__':
    model = Model()
    # ``torch.tensor`` is the documented way to build a tensor from data;
    # the float literals give the same float32 tensor as before.
    x = torch.tensor([[0.0, 1.0, 2.0]])
    # The forward call triggers the registered hook, which prints its report.
    y = model(x)
# 输出
weight tensor([[1., 1., 1.]])
bias tensor([0.])
input (tensor([[0., 1., 2.]]),)
output tensor([[3.]], grad_fn=<AddmmBackward>)
MMCV中HOOK基类
mmcv/runner/hooks/hook.py定义的基类
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, is_method_overridden
# Registry instance created under the name 'hook'.
HOOKS = Registry('hook')
class Hook:
    """Base class for hooks: no-op stage callbacks plus scheduling helpers.

    Subclasses override the callbacks they need. The train/val specific
    callbacks delegate to the generic ``before_epoch`` / ``after_epoch`` /
    ``before_iter`` / ``after_iter`` callbacks by default, so overriding a
    generic callback covers both the train and the val variant.
    """

    # Every stage name; ``get_triggered_stages`` reports stages in this order.
    stages = ('before_run', 'before_train_epoch', 'before_train_iter',
              'after_train_iter', 'after_train_epoch', 'before_val_epoch',
              'before_val_iter', 'after_val_iter', 'after_val_epoch',
              'after_run')

    def before_run(self, runner):
        """Stage callback; does nothing by default."""
        pass

    def after_run(self, runner):
        """Stage callback; does nothing by default."""
        pass

    def before_epoch(self, runner):
        """Generic epoch-start callback; does nothing by default."""
        pass

    def after_epoch(self, runner):
        """Generic epoch-end callback; does nothing by default."""
        pass

    def before_iter(self, runner):
        """Generic iteration-start callback; does nothing by default."""
        pass

    def after_iter(self, runner):
        """Generic iteration-end callback; does nothing by default."""
        pass

    def before_train_epoch(self, runner):
        """Delegate to ``before_epoch``."""
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        """Delegate to ``before_epoch``."""
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        """Delegate to ``after_epoch``."""
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        """Delegate to ``after_epoch``."""
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        """Delegate to ``before_iter``."""
        self.before_iter(runner)

    def before_val_iter(self, runner):
        """Delegate to ``before_iter``."""
        self.before_iter(runner)

    def after_train_iter(self, runner):
        """Delegate to ``after_iter``."""
        self.after_iter(runner)

    def after_val_iter(self, runner):
        """Delegate to ``after_iter``."""
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        """Return True if ``runner.epoch + 1`` is a multiple of ``n``.

        Always returns False when ``n`` is not positive.
        """
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        """Return True if ``runner.inner_iter + 1`` is a multiple of ``n``.

        Always returns False when ``n`` is not positive.
        """
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        """Return True if ``runner.iter + 1`` is a multiple of ``n``.

        Always returns False when ``n`` is not positive.
        """
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        """Return True if ``runner.inner_iter + 1`` equals the loader length."""
        return runner.inner_iter + 1 == len(runner.data_loader)

    def is_last_epoch(self, runner):
        """Return True if ``runner.epoch + 1`` equals ``runner._max_epochs``."""
        return runner.epoch + 1 == runner._max_epochs

    def is_last_iter(self, runner):
        """Return True if ``runner.iter + 1`` equals ``runner._max_iters``."""
        return runner.iter + 1 == runner._max_iters

    def get_triggered_stages(self):
        """Return the stages whose callbacks this hook overrides.

        A stage counts as triggered when its own callback is overridden, or
        when the generic callback it delegates to is overridden.

        Returns:
            list[str]: Triggered stage names, ordered as in ``Hook.stages``.
        """
        trigger_stages = set()
        for stage in Hook.stages:
            if is_method_overridden(stage, Hook, self):
                trigger_stages.add(stage)

        # some methods will be triggered in multi stages
        # use this dict to map method to stages.
        method_stages_map = {
            'before_epoch': ['before_train_epoch', 'before_val_epoch'],
            'after_epoch': ['after_train_epoch', 'after_val_epoch'],
            'before_iter': ['before_train_iter', 'before_val_iter'],
            'after_iter': ['after_train_iter', 'after_val_iter'],
        }

        for method, map_stages in method_stages_map.items():
            if is_method_overridden(method, Hook, self):
                trigger_stages.update(map_stages)

        # Filter through Hook.stages so the result has a stable order.
        return [stage for stage in Hook.stages if stage in trigger_stages]
Hook分类和用法
默认HOOK
- CheckpointHook(保存ckpt)
- LrUpdaterHook(学习率调度)
- OptimizerHook(反向传播+参数更新)
- Fp16OptimizerHook(混合精度训练)
- TextLoggerHook(日志打印)
- IterTimerHook(迭代一次时间统计)
- 分布式相关(DistSamplerSeedHook)确保shuffle生效
- MomentumUpdaterHook(动量更新,用于3d目标检测)
定制HOOK
- EMAHook(模型ema)
- TensorboardLoggerHook(tensorboard简单可视化)
- EmptyCacheHook(Pytorch cuda缓存清除)
- SyncBuffersHook(同步Buffer)
- ClosureHook(简单函数快速注册)
- 各大框架自定义hook(MMDetection中EvalHook和DistEvalHook)
自定义HOOK(UploadHook)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner.hooks import HOOKS, Hook
@HOOKS.register_module()
class UploadHook(Hook):
    """Demo hook that announces every stage and checks the training loss.

    Each callback prints an ``upload <stage>`` message. After a training
    iteration the loss is additionally checked for being finite every
    ``interval`` iterations.

    Note that the train/val specific callbacks are overridden here without
    delegating to the generic ``before_epoch`` / ``after_epoch`` /
    ``before_iter`` / ``after_iter`` callbacks, so those generic callbacks
    only run when they are invoked directly.

    Args:
        interval (int): Checking interval (every k iterations).
            Default: 1.
    """

    def __init__(self, interval=1):
        self.interval = interval

    def before_run(self, runner):
        print('upload before_run')

    def after_run(self, runner):
        print('upload after_run')

    def before_epoch(self, runner):
        print('upload before_epoch')

    def after_epoch(self, runner):
        print('upload after_epoch')

    def before_iter(self, runner):
        print('upload before_iter')

    def after_iter(self, runner):
        print('upload after_iter')

    def before_train_epoch(self, runner):
        print('upload before_train_epoch')

    def before_val_epoch(self, runner):
        print('upload before_val_epoch')

    def after_train_epoch(self, runner):
        print('upload after_train_epoch')

    def after_val_epoch(self, runner):
        print('upload after_val_epoch')

    def before_train_iter(self, runner):
        print('upload before_train_iter')

    def before_val_iter(self, runner):
        print('upload before_val_iter')

    def after_train_iter(self, runner):
        """Announce the stage and check the loss every ``interval`` iterations.

        The class previously defined this method twice; the later definition
        replaced the earlier one, so the loss check never ran. Both are merged
        here, keeping the message of the definition that was in effect.

        Raises:
            AssertionError: If ``runner.outputs['loss']`` is not finite.
        """
        print('upload after_train_iter')
        if self.every_n_iters(runner, self.interval):
            if not torch.isfinite(runner.outputs['loss']):
                # Raise explicitly instead of using ``assert`` so the check
                # survives ``python -O`` and the error carries the message.
                message = 'loss become infinite or NaN!'
                runner.logger.info(message)
                raise AssertionError(message)

    def after_val_iter(self, runner):
        print('upload after_val_iter')
Hook的运行顺序
If custom hooks have same priority with default hooks, custom hooks will be triggered after default hooks.
先运行default hook,再运行custom hook
参考
目标检测(MMdetection)-HOOK机制
MMCV 核心组件分析(六): Hook
【官方文档】Customize hooks