1 There is a lot more __call__-related machinery to run

Counter-example: [screenshot missing]


A Module's __call__ runs through all of this machinery before it finally gets around to forward: [screenshot missing]

2 What actually gets executed

Reading the source, hooks are woven into the call path. The actual execution order is as follows (a simplified sketch of this dispatch appears after the list):

  1. __call__ -> _call_impl(self, *input, **kwargs)
  2. forward_pre_hook: the global hooks run first, then the module's own hooks
    1. hooks registered via from torch.nn.modules.module import register_module_forward_pre_hook
      1. hook(module, input) -> None or modified input
      2. every registration function returns a handle that can be used as a with context and whose remove() detaches the hook; this holds for all the hooks below and won't be repeated
    2. nn.Module.register_forward_pre_hook
    3. this amounts to executing input = hook(self, input)
      1. with one convenience: if the hook returns None, input is left unchanged. All torch hooks follow this pattern.
      2. since input arrives packed as a tuple, a non-tuple return value is coerced back into a tuple before being stored into input. Inside a hook, it is simplest to treat input as the packed tuple throughout.
  3. is _get_tracing_state() set (i.e., is the module being JIT-traced)?
    1. if so, run _slow_forward, which internally still ends up calling forward
    2. otherwise, call forward directly
  4. forward_hook, same pattern as the pre-hooks: register_module_forward_hook hooks run first, then register_forward_hook hooks
    1. result = hook(self, input, result)
  5. backward_hooks, which rarely seem necessary in practice...
    1. grad_output = hook(module, grad_input, grad_output)
    2. the hook gets attached to result.grad_fn
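
Putting the list together, here is a heavily simplified paraphrase of _call_impl. This is a sketch, not the real source: the actual implementation (which also varies across PyTorch versions) handles more hook kinds, kwargs-aware hooks, and various edge cases.

    def _call_impl(self, *input, **kwargs):
        # 1) forward pre-hooks: global first, then this instance's own
        for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:          # None means "leave input alone"
                if not isinstance(result, tuple):
                    result = (result,)      # non-tuple returns get re-packed
                input = result

        # 2) the real forward (or _slow_forward while JIT tracing)
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)

        # 3) forward hooks: again global first, then the instance's own
        for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result

        # 4) backward hooks (if any) get attached to result.grad_fn here
        return result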

In step 2, the first registration path (register_module_forward_pre_hook) is global across all modules; the second (register_forward_pre_hook) applies to one specific instantiated model.
Note that the instance method targets a specific object: it only affects that one instance and cannot target a whole class of modules.
If you do need to handle a class of modules, use the global registration combined with a check like isinstance(module, nn.Linear), as in the sketch below.
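
For instance, a minimal sketch of that pattern (the hook body here is purely illustrative):

    import torch
    from torch import nn
    from torch.nn.modules.module import register_module_forward_pre_hook

    def linear_only_hook(module, inputs):
        # The global pre-hook fires for every module; filter by type manually.
        if isinstance(module, nn.Linear):
            print('Linear called:', module.in_features, '->', module.out_features)

    handle = register_module_forward_pre_hook(linear_only_hook)
    # ... run the model here ...
    handle.remove()  # detach when done, since this hook is process-wide

An alternative that avoids the global state is to iterate model.modules() and call register_forward_pre_hook only on the matching submodules.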

3 A hook application example: inspecting the feature size at every step

Global: add a hook to every module

    import numpy as np
    import torch
    from torch import nn
    from torch.nn.modules.module import register_module_forward_pre_hook
    from pyxllib.xl import shorten  # helper that folds text onto a single line

    def build_dataloader(n):
        """ Generate n samples; each sample is 5 integers in [0, 6).
        Samples fall into two classes by their sum: 10 < sum(x) < 20 is class 1, everything else is class 0.
        """
        data = np.random.randint(0, 6, [n, 5])
        dataset = [(torch.FloatTensor(x), int(10 < sum(x) < 20)) for x in data]
        return torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

    class NumNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.classifier = nn.Sequential(
                nn.Linear(5, 10),
                nn.ReLU(),
                nn.Linear(10, 5),
                nn.ReLU(),
                nn.Linear(5, 2),
                nn.ReLU(),
            )

        def forward(self, batched_inputs):
            x, y = batched_inputs
            logits = self.classifier(x)
            loss = nn.functional.cross_entropy(logits, y)
            return loss

    def print_feature_size(module, inputs):
        inputdata = inputs[0]  # positional args arrive packed; inputs[0] is the original batched_inputs
        if isinstance(inputdata, (list, tuple)):
            x, y = inputdata  # the batched_inputs passed to NumNet
        else:
            x = inputdata  # Linear, ReLU, etc. receive only x in forward
        print(x.shape, '-->', shorten(module))

    if __name__ == '__main__':
        register_module_forward_pre_hook(print_feature_size)
        dataloader = build_dataloader(20)  # 20 samples with batch_size=16, so two batches
        model = NumNet()
        for batched_inputs in dataloader:
            model(batched_inputs)
            print('= ' * 20)
    # torch.Size([16, 5]) --> NumNet( (classifier): Sequential( (0): Linear(in_features=5, out_features=10, bias=True) (1): ReLU() (2): Linear(in_features=10, out_features=5, bias=True) (3): ReLU() (4): Linear(in_features=5, ou...
    # torch.Size([16, 5]) --> Sequential( (0): Linear(in_features=5, out_features=10, bias=True) (1): ReLU() (2): Linear(in_features=10, out_features=5, bias=True) (3): ReLU() (4): Linear(in_features=5, out_features=2, bias=Tru...
    # torch.Size([16, 5]) --> Linear(in_features=5, out_features=10, bias=True)
    # torch.Size([16, 10]) --> ReLU()
    # torch.Size([16, 10]) --> Linear(in_features=10, out_features=5, bias=True)
    # torch.Size([16, 5]) --> ReLU()
    # torch.Size([16, 5]) --> Linear(in_features=5, out_features=2, bias=True)
    # torch.Size([16, 2]) --> ReLU()
    # = = = = = = = = = = = = = = = = = = = =
    # torch.Size([4, 5]) --> NumNet( (classifier): Sequential( (0): Linear(in_features=5, out_features=10, bias=True) (1): ReLU() (2): Linear(in_features=10, out_features=5, bias=True) (3): ReLU() (4): Linear(in_features=5, ou...
    # torch.Size([4, 5]) --> Sequential( (0): Linear(in_features=5, out_features=10, bias=True) (1): ReLU() (2): Linear(in_features=10, out_features=5, bias=True) (3): ReLU() (4): Linear(in_features=5, out_features=2, bias=Tru...
    # torch.Size([4, 5]) --> Linear(in_features=5, out_features=10, bias=True)
    # torch.Size([4, 10]) --> ReLU()
    # torch.Size([4, 10]) --> Linear(in_features=10, out_features=5, bias=True)
    # torch.Size([4, 5]) --> ReLU()
    # torch.Size([4, 5]) --> Linear(in_features=5, out_features=2, bias=True)
    # torch.Size([4, 2]) --> ReLU()
    # = = = = = = = = = = = = = = = = = = = =

This inspects the size of the data right before it is fed into each module's forward.
To inspect the size after forward has run, use
from torch.nn.modules.module import register_module_forward_hook instead, as sketched below.
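
Continuing the script above (and reusing shorten), a minimal sketch; the forward hook receives the module's output as a third argument:

    from torch.nn.modules.module import register_module_forward_hook

    def print_output_size(module, inputs, output):
        # Runs after forward; output is whatever forward returned
        # (for NumNet itself that is a 0-dim loss tensor).
        if torch.is_tensor(output):
            print(shorten(module), '-->', output.shape)

    register_module_forward_hook(print_output_size)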

Add a hook to just one specific instantiated model

Everything above stays the same; only the main block changes:

    if __name__ == '__main__':
        # register_module_forward_pre_hook(print_feature_size)
        dataloader = build_dataloader(20)  # 20 samples with batch_size=16, so two batches
        model = NumNet()
        model.register_forward_pre_hook(print_feature_size)
        for batched_inputs in dataloader:
            model(batched_inputs)
            print('= ' * 20)
    # torch.Size([16, 5]) --> NumNet( (classifier): Sequential( (0): Linear(in_features=5, out_features=10, bias=True) (1): ReLU() (2): Linear(in_features=10, out_features=5, bias=True) (3): ReLU() (4): Linear(in_features=5, ou...
    # = = = = = = = = = = = = = = = = = = = =
    # torch.Size([4, 5]) --> NumNet( (classifier): Sequential( (0): Linear(in_features=5, out_features=10, bias=True) (1): ReLU() (2): Linear(in_features=10, out_features=5, bias=True) (3): ReLU() (4): Linear(in_features=5, ou...
    # = = = = = = = = = = = = = = = = = = = =
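
As noted in section 2, every registration returns a removable handle, so a hook like this can be scoped to just the batches you want to inspect. A minimal sketch:

    handle = model.register_forward_pre_hook(print_feature_size)
    for batched_inputs in dataloader:
        model(batched_inputs)      # shapes get printed here
    handle.remove()                # detach the hook
    model(batched_inputs)          # runs silently now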