forword

我们在使用Pytorch的时候,模型训练时,不需要调用forward这个函数,只需要在实例化一个对象中传入对应的参数就可以自动调用 forward 函数。

  1. class Module(nn.Module):
  2. def __init__(self):
  3. super().__init__()
  4. # ......
  5. def forward(self, x):
  6. # ......
  7. return x
  8. data = ...... # 输入数据
  9. # 实例化一个对象
  10. model = Module()
  11. # 前向传播
  12. model(data)
  13. # 而不是使用下面的
  14. # model.forward(data)

但是实际上model(data)是等价于model.forward(data),model(data)之所以等价于model.forward(data),就是因为在类(class)中使用了call函数,

  1. class Student:
  2. def __call__(self):
  3. print('I can be called like a function')
  4. a = Student()
  5. a()
  1. I can be called like a function

因为 PyTorch 中的大部分方法都继承自 torch.nn.Module,而 torch.nn.Module 的call(self)函数中会返回 forward()函数 的结果,因此PyTroch中的 forward()函数等于是被嵌套在了call(self)函数中;因此forward()函数可以直接通过类名被调用,而不用实例化对象