forword
我们在使用Pytorch的时候,模型训练时,不需要调用forward这个函数,只需要在实例化一个对象中传入对应的参数就可以自动调用 forward 函数。
class Module(nn.Module):
def __init__(self):
super().__init__()
# ......
def forward(self, x):
# ......
return x
data = ...... # 输入数据
# 实例化一个对象
model = Module()
# 前向传播
model(data)
# 而不是使用下面的
# model.forward(data)
但是实际上model(data)是等价于model.forward(data),model(data)之所以等价于model.forward(data),就是因为在类(class)中使用了call函数,
class Student:
def __call__(self):
print('I can be called like a function')
a = Student()
a()
I can be called like a function
因为 PyTorch 中的大部分方法都继承自 torch.nn.Module,而 torch.nn.Module 的call(self)函数中会返回 forward()函数 的结果,因此PyTroch中的 forward()函数等于是被嵌套在了call(self)函数中;因此forward()函数可以直接通过类名被调用,而不用实例化对象