想训练一个网络,网络本身很简单结构,无关的部分我都注释了,如下:
class ResnetGenerator(nn.Module):def __init__(self, input_nc=3, output_nc=3, ngf=8, n_blocks=4):super(ResnetGenerator, self).__init__()self.resblock = []for i in range(n_blocks):self.resblock += [ResidualBasic_in(output_nc)]def forward(self, y_hr, x_hr):y_hr = self.resblock(torch.add(x_hr, y_hr))return y_hr
但是报错了,报错信息如下:TypeError: 'list' object is not callable
原因是代码中的resblock我定义成了一个list,而非一个nn.Module模块,需要使用nn.Sequential
修改后的代码如下,只需要在第7行加上self``.resblock = nn.Sequential(*``self``.resblock)
class ResnetGenerator(nn.Module):def __init__(self, input_nc=3, output_nc=3, ngf=8, n_blocks=4):super(ResnetGenerator, self).__init__()self.resblock = []for i in range(n_blocks):self.resblock += [ResidualBasic_in(output_nc)]self.resblock = nn.Sequential(*self.resblock)def forward(self, y_hr, x_hr):y_hr = self.resblock(torch.add(x_hr, y_hr))return y_hr
