The optimizer hook implements backpropagation

In `after_train_iter`, the hook clears stale gradients, backpropagates the loss from the forward pass, optionally clips gradients (logging the resulting gradient norm), and finally steps the optimizer:

    from torch.nn.utils import clip_grad

    from mmcv.runner import HOOKS, Hook


    @HOOKS.register_module()
    class OptimizerHook(Hook):

        def __init__(self, grad_clip=None):
            self.grad_clip = grad_clip

        def clip_grads(self, params):
            # Only keep parameters that require grad and actually received
            # a gradient in this backward pass.
            params = list(
                filter(lambda p: p.requires_grad and p.grad is not None, params))
            if len(params) > 0:
                return clip_grad.clip_grad_norm_(params, **self.grad_clip)

        def after_train_iter(self, runner):
            runner.optimizer.zero_grad()
            runner.outputs['loss'].backward()
            if self.grad_clip is not None:
                grad_norm = self.clip_grads(runner.model.parameters())
                if grad_norm is not None:
                    # Add grad norm to the logger
                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                             runner.outputs['num_samples'])
            runner.optimizer.step()
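Because `clip_grads` unpacks `self.grad_clip` as keyword arguments to `torch.nn.utils.clip_grad_norm_`, `grad_clip` is expected to be a dict matching that function's parameters. A minimal usage sketch, assuming a runner has already been constructed (the `max_norm`/`norm_type` values here are illustrative, not prescribed by the source):

    # Sketch: register the hook with gradient clipping enabled. The keys of
    # grad_clip are forwarded to torch.nn.utils.clip_grad_norm_, so only
    # parameters of that function (e.g. max_norm, norm_type) are valid.
    hook = OptimizerHook(grad_clip=dict(max_norm=35, norm_type=2))
    runner.register_hook(hook)

Passing `grad_clip=None` (the default) skips clipping entirely, so the hook then reduces to the plain `zero_grad` / `backward` / `step` cycle.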