想训练一个网络,网络本身很简单结构,无关的部分我都注释了,如下:

    1. class ResnetGenerator(nn.Module):
    2. def __init__(self, input_nc=3, output_nc=3, ngf=8, n_blocks=4):
    3. super(ResnetGenerator, self).__init__()
    4. self.resblock = []
    5. for i in range(n_blocks):
    6. self.resblock += [ResidualBasic_in(output_nc)]
    7. def forward(self, y_hr, x_hr):
    8. y_hr = self.resblock(torch.add(x_hr, y_hr))
    9. return y_hr

    但是报错了,报错信息如下:
    TypeError: 'list' object is not callable

    原因是代码中的resblock我定义成了一个list,而非一个nn.Module模块,需要使用nn.Sequential

    修改后的代码如下,只需要在第7行加上self``.resblock = nn.Sequential(*``self``.resblock)

    1. class ResnetGenerator(nn.Module):
    2. def __init__(self, input_nc=3, output_nc=3, ngf=8, n_blocks=4):
    3. super(ResnetGenerator, self).__init__()
    4. self.resblock = []
    5. for i in range(n_blocks):
    6. self.resblock += [ResidualBasic_in(output_nc)]
    7. self.resblock = nn.Sequential(*self.resblock)
    8. def forward(self, y_hr, x_hr):
    9. y_hr = self.resblock(torch.add(x_hr, y_hr))
    10. return y_hr