1. Hook函数概念
  2. Hook函数与特征图提取
  3. CAM(class activation map, 类激活图)

一、Hook

Hook函数机制:不改变主体,实现额外功能,像一个挂件,挂钩,hook

pytorch 是动态图机制,当动态图运算结束之后,一些中间变量(特征图、非叶子节点的梯度等)会被释放掉。但是我们有提取中间变量的需求,所以有了Hook函数

PyTorch 提供的四种 Hook 函数

  1. torch.Tensor.register_hook(hook) # 针对tensor的
  2. # 针对module的Hook函数
  3. torch.nn.Module.register_forward_hook
  4. torch.nn.Module.register_forward_pre_hook
  5. torch.nn.Module.register_backward_hook

01. Tensor.register_hook()

  1. hook(grad) -> Tensor or None
  • 功能:注册一个反向传播hook函数。因为非叶子节点的tensor会在反向传播过程中梯度会消失。
  • Hook函数仅一个输入参数:张量的梯度

    02. Module.register_forward_hook()

    1. hook(module, input, output) -> None
  • 功能:注册module的前向传播hook函数

  • 参数:
    • module : 当前网络层
    • input :当前网络层输入数据
    • output :当前网络层输出数据

      二、Hook函数与特征图提取

三、CAM