不带模型参数
对于不带模型参数的自定义层,我们既可以对它进行实例化,也可以将他作为一层隐藏层放在网络当中。对于这种自定义层,可以不需要明确传入的参数形状。
class CenteredLayer(nn.Module):def __init__(self, **kwargs):super(CenteredLayer, self).__init__(**kwargs)def forward(self, x):return x - x.mean()# 可以实例化该层layer = CenteredLayer()print(layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float)))# 也可以使用它构造更复杂的模型net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())y = net(torch.rand(2, 8))print(y.mean().item())
含模型参数
在自定义含模型参数的层时,我们应该将参数定义成
Parameter,除了直接定义成Parameter类外,还可以使用ParameterList和ParameterDict分别定义参数的列表和字典。ParameterList接收一个Parameter实例的列表作为输入然后得到一个参数列表,使用的时候可以用索引来访问某个参数,另外也可以使用append和extend在列表后面新增参数。
class MyDense(nn.Module):def __init__(self):super(MyDense, self).__init__()self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])self.params.append(nn.Parameter(torch.randn(4, 1)))def forward(self, x):for i in range(len(self.params)):x = torch.mm(x, self.params[i])return xnet = MyDense()print(net(torch.rand(3, 4)))print(net)结果:tensor([[-0.8253],[ 4.7919],[ 6.3828]], grad_fn=<MmBackward>)MyDense((params): ParameterList((0): Parameter containing: [torch.FloatTensor of size 4x4](1): Parameter containing: [torch.FloatTensor of size 4x4](2): Parameter containing: [torch.FloatTensor of size 4x4](3): Parameter containing: [torch.FloatTensor of size 4x1]))
而
ParameterDict接收一个Parameter实例的字典作为输入然后得到一个参数字典,然后可以按照字典的规则使用了。例如使用update()新增参数,使用keys()返回所有键值,使用items()返回所有键值对等等,可参考官方文档。
class MyDictDense(nn.Module):def __init__(self):super(MyDictDense, self).__init__()self.params = nn.ParameterDict({'linear1': nn.Parameter((torch.randn(4, 4))),'linear2': nn.Parameter((torch.randn(4, 1)))})self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))})def forward(self, x, choice='linear1'):return torch.mm(x, self.params[choice])net = MyDictDense()print(net(torch.rand(3, 4), 'linear2'))结果:tensor([[ 0.2083],[ 1.0280],[-0.3583]], grad_fn=<MmBackward>)
使用update更新参数时,会将新的参数加入到字典中。不过这种用字典的方式貌似是无序的,顺序还得通过自己手动定义forward函数来实现。
我们也完全可以使用这些自定义层来构造模型:
net = nn.Sequential(MyDictDense(),MyDense(),)print(net)结果:Sequential((0): MyDictDense((params): ParameterDict((linear1): Parameter containing: [torch.FloatTensor of size 4x4](linear2): Parameter containing: [torch.FloatTensor of size 4x1](linear3): Parameter containing: [torch.FloatTensor of size 4x2]))(1): MyDense((params): ParameterList((0): Parameter containing: [torch.FloatTensor of size 4x4](1): Parameter containing: [torch.FloatTensor of size 4x4](2): Parameter containing: [torch.FloatTensor of size 4x4](3): Parameter containing: [torch.FloatTensor of size 4x1])))
