Hook Functions

The hook mechanism attaches extra functionality without changing the main computation, like a pendant hanging from a hook, hence the name. Because PyTorch automatically discards the intermediate results of the computational graph, hook functions are needed to capture these values. Hooks come in two kinds, Tensor (formerly Variable) hooks and nn.Module hooks, and their usage is similar.
PyTorch provides four hook functions:
1. torch.Tensor.register_hook(hook)
2. torch.nn.Module.register_forward_hook
3. torch.nn.Module.register_forward_pre_hook
4. torch.nn.Module.register_backward_hook

1. torch.Tensor.register_hook

Purpose: registers a backward hook on a tensor. The hook function takes a single argument, the tensor's gradient. The hook should not modify the gradient it receives, but it can optionally return a new gradient that will be used in place of grad:
hook(grad) -> Tensor or None
From the official documentation:
Registers a backward hook.
The hook will be called every time a gradient with respect to the Tensor is computed. The hook should have the following signature:
hook(grad) -> Tensor or None
The hook should not modify its argument, but it can optionally return a new gradient which will be used in place of grad.
This function returns a handle with a method handle.remove() that removes the hook from the module.

  >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
  >>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
  >>> v.backward(torch.tensor([1., 2., 3.]))
  >>> v.grad
   2
   4
   6
  [torch.FloatTensor of size (3,)]
  >>> h.remove()  # removes the hook
  import torch

  # x and y are leaf nodes: PyTorch only retains gradients for leaf nodes
  x = torch.tensor([3.], requires_grad=True)
  y = torch.tensor([5.], requires_grad=True)

  # a and b are intermediate values; their gradients are freed during backward
  a = x + y
  b = x * y
  c = a * b

  # list used to store the intermediate gradient captured by the hook
  a_grad = []

  def hook_grad(grad):
      a_grad.append(grad)

  # register_hook takes a function as its argument
  handle = a.register_hook(hook_grad)
  c.backward()

  # only leaf nodes retain their gradient
  print('gradient:', x.grad, y.grad, a.grad, b.grad, c.grad)
  # gradient of the intermediate node a, saved by the hook
  print('a_grad:', a_grad[0])
  # remove the hook
  handle.remove()

  # out:
  # gradient: tensor([55.]) tensor([39.]) None None None
  # a_grad: tensor([15.])

2. torch.nn.Module.register_forward_hook

Purpose: registers a hook for the module's forward pass.
Parameters of the hook function:

  • module: the current layer
  • input: the input tensors of the current layer
  • output: the output of the current layer

Registers a forward hook on the module.

The hook will be called every time after forward() has computed an output. It should have the following signature:

hook(module, input, output) -> None or modified output
The input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to the forward. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after forward() is called.
Explained with code:

  import torch
  import torch.nn as nn

  # build a network: one convolutional layer and one pooling layer
  class Net(nn.Module):
      def __init__(self):
          super(Net, self).__init__()
          self.conv1 = nn.Conv2d(1, 2, 3)
          self.pool1 = nn.MaxPool2d(2)

      def forward(self, x):
          x = self.conv1(x)
          x = self.pool1(x)
          return x

  # initialize the network
  net = Net()
  # detach the parameters so they can be filled in place
  net.conv1.weight[0].detach().fill_(1)
  net.conv1.weight[1].detach().fill_(2)
  net.conv1.bias.detach().zero_()

  # two lists to store what the hook captures
  fmap_block = []
  input_block = []

  def forward_hook(module, data_input, data_output):
      fmap_block.append(data_output)
      input_block.append(data_input)

  # register the hook on the conv layer
  net.conv1.register_forward_hook(forward_hook)

  # input data
  fake_img = torch.ones((1, 1, 4, 4))
  output = net(fake_img)

  # inspect the results
  # shape and values of the network output
  print("output shape:{}\noutput value:{}\n".format(output.size(), output))
  # feature map captured by the hook on conv1
  print("feature map shape:{}\noutput value:{}\n".format(fmap_block[0].shape, fmap_block[0]))
  # input captured by the hook
  print("input shape:{}\ninput value:{}\n".format(input_block[0][0].size(), input_block[0][0]))
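With the parameters fixed as above, the result is deterministic: every 3x3 window of the all-ones 4x4 input sums to 9, so conv1 produces a (1, 2, 2, 2) feature map filled with 9 (channel 0, weights of 1) and 18 (channel 1, weights of 2), and the 2x2 max pooling reduces it to a (1, 2, 1, 1) output holding 9 and 18. The hook captures conv1's output (the pre-pooling feature map) and its input (the original image), exactly the intermediate data a forward hook is meant to expose.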

3. torch.nn.Module.register_forward_pre_hook

Purpose: registers a hook that runs before the module's forward pass.
Parameters of the hook function:

  • module: the current layer
  • input: the input tensors of the current layer

Registers a forward pre-hook on the module.
The hook will be called every time before forward() is invoked. It should have the following signature:
hook(module, input) -> None or modified input
The input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to the forward. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned(unless that value is already a tuple).
Returns
a handle that can be used to remove the added hook by calling handle.remove()
Return type
torch.utils.hooks.RemovableHandle
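The original section stops at the documentation, so here is a minimal sketch of a forward pre-hook; the Linear layer and the doubling of the input are illustrative assumptions, not part of the original text:

  import torch
  import torch.nn as nn

  net = nn.Linear(3, 1)

  # the pre-hook receives the input as a tuple of positional arguments
  def forward_pre_hook(module, data_input):
      print("pre-hook input:", data_input)
      # returning a value replaces the input; a single return value is wrapped into a tuple
      return data_input[0] * 2

  handle = net.register_forward_pre_hook(forward_pre_hook)
  x = torch.ones(1, 3)
  output = net(x)  # forward() now receives x * 2
  handle.remove()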

4. torch.nn.Module.register_backward_hook

Purpose: registers a hook for the module's backward pass.
Parameters of the hook function:

  • module: the current layer
  • grad_input: gradients with respect to the layer's inputs
  • grad_output: gradients with respect to the layer's outputs

register_backward_hook(hook: Callable[[Module, Union[Tuple[torch.Tensor, …], torch.Tensor], Union[Tuple[torch.Tensor, …], torch.Tensor]], Union[None, torch.Tensor]]) → torch.utils.hooks.RemovableHandle
Registers a backward hook on the module.
WARNING
The current implementation will not have the presented behavior for complex Module that perform many operations. In some failure cases, grad_input and grad_output will only contain the gradients for a subset of the inputs and outputs. For such Module, you should use torch.Tensor.register_hook() directly on a specific input or output to get the required gradients.
The hook will be called every time the gradients with respect to module inputs are computed. The hook should have the following signature:
hook(module, grad_input, grad_output) -> Tensor or None
The grad_input and grad_output may be tuples if the module has multiple inputs or outputs. The hook should not modify its arguments, but it can optionally return a new gradient with respect to input that will be used in place of grad_input in subsequent computations. grad_input will only correspond to the inputs given as positional arguments.
Returns
a handle that can be used to remove the added hook by calling handle.remove()
Return type
torch.utils.hooks.RemovableHandle
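As with the previous section, the original has no code walkthrough here, so below is a minimal sketch of a backward hook on a single Linear layer (the layer and the sum loss are illustrative assumptions). Note that recent PyTorch releases deprecate register_backward_hook in favor of register_full_backward_hook, which has the same signature and avoids the caveat in the WARNING above:

  import torch
  import torch.nn as nn

  net = nn.Linear(3, 1)

  # called once the gradients w.r.t. the module's inputs have been computed
  def backward_hook(module, grad_input, grad_output):
      # grad_output: gradients w.r.t. the module's outputs
      # grad_input: gradients w.r.t. the module's inputs (see the WARNING above)
      print("grad_input:", grad_input)
      print("grad_output:", grad_output)

  handle = net.register_backward_hook(backward_hook)
  x = torch.ones(1, 3, requires_grad=True)
  net(x).sum().backward()  # triggers the hook during the backward pass
  handle.remove()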

References

https://zhuanlan.zhihu.com/p/73868323