本博文由TensorSense发表于PyTorch的hook及其在Grad-CAM中的应用,转载请注明出处。

一、hook简介

pytorch中的hook是一个非常有意思的概念,hook意为钩、挂钩、鱼钩。引用知乎用户“马索萌”对hook的解释:“(hook)相当于插件。可以实现一些额外的功能,而又不用修改主体代码。把这些额外功能实现了挂在主代码上,所以叫钩子,很形象。

简单讲,就是不修改主体,而实现额外功能。对应到在pytorch中,主体就是forward和backward,而额外的功能就是对模型的变量进行操作,如“提取”特征图,“提取”非叶子张量的梯度,修改张量梯度等等。

hook的出现与pytorch运算机制有关,pytorch在每一次运算结束后,会将中间变量释放,以节省内存空间,这些会被释放的变量包括非叶子张量的梯度,中间层的特征图等。但有时候,我们想可视化中间层的特征图,又不能改动模型主体代码,该怎么办呢?这时候就要用到hook了。

举个例子演示hook提取非叶子张量的梯度:

  1. import torch
  2. def grad_hook(grad):
  3. y_grad.append(grad)
  4. y_grad = list()
  5. x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
  6. y = x+1
  7. y.register_hook(grad_hook)
  8. z = torch.mean(y*y)
  9. z.backward()
  10. print("type(y): ", type(y))
  11. print("y.grad: ", y.grad)
  12. print("y_grad[0]: ", y_grad[0])
  13. >>> ('type(y): ', <class 'torch.Tensor'>)
  14. >>> ('y.grad: ', None)
  15. >>> ('y_grad[0]: ', tensor([[1.0000, 1.5000],
  16. [2.0000, 2.5000]]))

可以看到y.grad的值为None,这是因为y是非叶子结点张量,在z.backward()完成之后,y的梯度被释放掉以节省内存,但可以通过torch.Tensor的类方法register_hook将y的梯度提取出来。

二、PyTorch的四个hook

PyTorch(1.1.0版)有如下4个hook:

  • torch.Tensor.register_hook (Python method, in torch.Tensor)
  • torch.nn.Module.register_forward_hook (Python method, in torch.nn)
  • torch.nn.Module.register_backward_hook (Python method, in torch.nn)
  • torch.nn.Module.register_forward_pre_hook (Python method, in torch.nn)


这4个hook中有一个是应用于tensor的,另外3个是针对nn.Module的。

1. torch.Tensor.register_hook(hook)

功能:注册一个反向传播hook函数,这个函数是Tensor类里的,当计算tensor的梯度时自动执行。
为什么是backward?因为这个hook是针对tensor的,tensor中的什么东西会在计算结束后释放呢?
只有gradient嘛,所以是 backward hook.

形式: hook(grad) -> Tensor or None ,其中grad就是这个tensor的梯度。

返回值:a handle that can be used to remove the added hook by calling handle.remove()

应用场景举例:在hook函数中可对梯度grad进行in-place操作,即可修改tensor的grad值。
这是一个很酷的功能,例如当浅层的梯度消失时,可以对浅层的梯度乘以一定的倍数,用来增大梯度;
还可以对梯度做截断,限制梯度在某一区间,防止过大的梯度对权值参数进行修改。
下面举两个例子,例1是如何获取中间变量y的梯度例2是利用hook函数将变量x的梯度扩大2倍。

例1:**

  1. import torch
  2. y_grad = list()
  3. def grad_hook(grad):
  4. y_grad.append(grad)
  5. x = torch.tensor([2., 2., 2., 2.], requires_grad=True)
  6. y = torch.pow(x, 2)
  7. z = torch.mean(y)
  8. h = y.register_hook(grad_hook)
  9. z.backward()
  10. print("y.grad: ", y.grad)
  11. print("y_grad[0]: ", y_grad[0])
  12. h.remove() # removes the hook
  13. >>> ('y.grad: ', None)
  14. >>> ('y_grad[0]: ', tensor([0.2500, 0.2500, 0.2500, 0.2500]))

可以看到当z.backward()结束后,张量y中的grad为None,因为y是非叶子节点张量,在梯度反传结束之后,被释放。
在对张量y的hook函数(grad_hook)中,将y的梯度保存到了y_grad列表中,因此可以在z.backward()结束后,仍旧可以在y_grad[0]中读到y的梯度为tensor([0.2500, 0.2500, 0.2500, 0.2500])

例2:

  1. import torch
  2. def grad_hook(grad):
  3. grad *= 2
  4. x = torch.tensor([2., 2., 2., 2.], requires_grad=True)
  5. y = torch.pow(x, 2)
  6. z = torch.mean(y)
  7. h = x.register_hook(grad_hook)
  8. z.backward()
  9. print(x.grad)
  10. h.remove() # removes the hook
  11. >>> tensor([2., 2., 2., 2.])

原x的梯度为tensor([1., 1., 1., 1.]),经grad_hook操作后,梯度为tensor([2., 2., 2., 2.])。

2. torch.nn.Module.register_forward_hook

功能:Module前向传播中的hook,module在前向传播后,自动调用hook函数。
形式:hook(module, input, output) -> None。注意不能修改input和output
返回值:a handle that can be used to remove the added hook by calling handle.remove()

应用场景举例:用于提取特征图
举例:假设网络由卷积层conv1和池化层pool1构成,输入一张4*4的图片,现采用forward_hook获取module——conv1之后的feature maps,示意图如下:
image.png

  1. import torch
  2. import torch.nn as nn
  3. class Net(nn.Module):
  4. def __init__(self):
  5. super(Net, self).__init__()
  6. self.conv1 = nn.Conv2d(1, 2, 3)
  7. self.pool1 = nn.MaxPool2d(2, 2)
  8. def forward(self, x):
  9. x = self.conv1(x)
  10. x = self.pool1(x)
  11. return x
  12. def farward_hook(module, data_input, data_output):
  13. fmap_block.append(data_output)
  14. input_block.append(data_input)
  15. if __name__ == "__main__":
  16. # 初始化网络
  17. net = Net()
  18. net.conv1.weight[0].fill_(1)
  19. net.conv1.weight[1].fill_(2)
  20. net.conv1.bias.data.zero_()
  21. # 注册hook
  22. fmap_block = list()
  23. input_block = list()
  24. net.conv1.register_forward_hook(farward_hook)
  25. # inference
  26. fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W
  27. output = net(fake_img)
  28. # 观察
  29. print("output shape: {}\noutput value: {}\n".format(output.shape, output))
  30. print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
  31. print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

首先初始化一个网络,卷积层有两个卷积核,权值分别为全1和全2,bias设置为0,池化层采用2*2的最大池化。

在进行forward之前对module——conv1注册了forward_hook函数,然后执行前向传播(output=net(fake_img)),当前向传播完成后,fmap_block列表中的第一个元素就是conv1层输出的特征图了。

这里注意观察farward_hook函数有data_input和data_output两个变量,特征图是data_output这个变量,而data_input是conv1层的输入数据,conv1层的输入是一个tuple的形式。

  1. OUT:
  2. output shape: torch.Size([1, 2, 1, 1])
  3. output value: tensor([[[[ 9.]],
  4. [[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward>)
  5. feature maps shape: torch.Size([1, 2, 2, 2])
  6. output value: tensor([[[[ 9., 9.],
  7. [ 9., 9.]],
  8. [[18., 18.],
  9. [18., 18.]]]], grad_fn=<MkldnnConvolutionBackward>)
  10. input shape: torch.Size([1, 1, 4, 4])
  11. input value: (tensor([[[[1., 1., 1., 1.],
  12. [1., 1., 1., 1.],
  13. [1., 1., 1., 1.],
  14. [1., 1., 1., 1.]]]]),)

下面剖析一下module是怎么样调用hook函数的:
**

1.output = net(fake_img)

net是一个module类,对module执行 module(input)是会调用module.call

2.module.call

在module.call中执行流程如下:

  1. def __call__(self, *input, **kwargs):
  2. for hook in self._forward_pre_hooks.values():
  3. hook(self, input)
  4. if torch._C._get_tracing_state():
  5. result = self._slow_forward(*input, **kwargs)
  6. else:
  7. result = self.forward(*input, **kwargs)
  8. for hook in self._forward_hooks.values():
  9. hook_result = hook(self, input, result)
  10. if hook_result is not None:
  11. raise RuntimeError(
  12. "forward hooks should never return any values, but '{}'"
  13. "didn't return None".format(hook))
  14. ...省略

首先判断module(这里是net)是否有forwardprehook,即在执行forward之前的hook;
然后执行forward;
forward结束之后才到forward_hook。
但是这里主要了,现在执行的是net.call,我们组成的hook是在module——net.conv1中,
所以第2个跳转是在net.__call
的 result = self.forward(input, *kwargs)

3.net.forward

  1. def forward(self, x):
  2. x = self.conv1(x)
  3. x = self.pool1(x)
  4. return x

在net.forward中,首先执行self.conv1(x), 而 conv1是一个nn.Conv2d(也是一个module类)。
在2中有说到,对module执行 module(input)是会调用module.call,因此第四步

4.nn.Conv2d.call

在nn.Conv2d.call中与2中说到的流程是一样的,再看一遍代码:

  1. def __call__(self, *input, **kwargs):
  2. for hook in self._forward_pre_hooks.values():
  3. hook(self, input)
  4. if torch._C._get_tracing_state():
  5. result = self._slow_forward(*input, **kwargs)
  6. else:
  7. result = self.forward(*input, **kwargs)
  8. for hook in self._forward_hooks.values():
  9. hook_result = hook(self, input, result)
  10. if hook_result is not None:
  11. raise RuntimeError(
  12. "forward hooks should never return any values, but '{}'"
  13. "didn't return None".format(hook))

在这里终于要执行我们注册的forward_hook函数了,就在hook_result = hook(self, input, result)这里!
看到这里我们需要注意两点:

  1. hook_result = hook(self, input, result)中的input和result不可以修改!这里的input对应forward_hook函数中的data_input,result对应forward_hook函数中的data_output,在conv1中,input就是该层的输入数据result就是经过conv1层操作之后的输出特征图。虽然可以通过hook来对这些数据操作,但是不能修改这些值,否则会破坏模型的计算。
  2. 注册的hook函数是不能带返回值的,否则抛出异常,这个可以从代码中看到.

总结一下调用流程:

net(fake_img) —> net.call : result = self.forward(input, *kwargs) —>
net.forward: x = self.conv1(x) —> conv1.call:hook_result = hook(self, input, result)
hook就是我们注册的forward_hook函数了。

3. torch.nn.Module.register_forward_pre_hook

功能:执行forward()之前调用hook函数。
形式:hook(module, input) -> None
应用场景举例:暂时没碰到过,希望读者朋友补充register_forward_pre_hook相关应用场景。

4.torch.nn.Module.register_backward_hook

功能:Module反向传播中的hook,每次计算module的梯度后,自动调用hook函数。
形式:hook(module, grad_input, grad_output) -> Tensor or None
注意事项:当module有多个输入或输出时,grad_input和grad_output是一个tuple。
返回值:a handle that can be used to remove the added hook by calling handle.remove()

应用场景举例:例如提取特征图的梯度
举例:采用register_backward_hook实现特征图梯度的提取,并结合Grad-CAM(基于类梯度的类激活图可视化)方法对卷积神经网络的学习模式进行可视化。

Grad-CAM是对特征图进行求梯度,将每一张特征图上的梯度求平均得到权值(特征图的梯度是element-wise的)。求梯度时并不采用网络的输出,而是采用类向量,即one-hot向量。