PyTorch: Control Flow + Weight Sharing
To showcase the power of PyTorch's dynamic graphs, we will implement a very strange model: a polynomial whose degree changes dynamically between 3 and 5. On each forward pass the model picks a random degree in that range, and the fourth- and fifth-order terms share a single weight that is reused across those computations.
import random
import math

import torch


class DynamicNet(torch.nn.Module):
    def __init__(self):
        """
        In the constructor we instantiate five parameters and assign them
        as members.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))
        self.e = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose the polynomial
        degree to be 3, 4, or 5 and reuse the e parameter to compute the
        contribution of the fourth- and fifth-order terms.

        Since each forward pass builds a dynamic computation graph, we can use
        normal Python control-flow operators like loops or conditional
        statements when defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same parameter
        many times when defining a computational graph.
        """
        y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
        for exp in range(4, random.randint(4, 6)):
            y = y + self.e * x ** exp
        return y

    def string(self):
        """
        Just like any class in Python, you can also define custom methods
        on PyTorch modules.
        """
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?'


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above.
model = DynamicNet()

# Construct our loss function and an Optimizer. Training this strange model
# with vanilla stochastic gradient descent is tough, so we use momentum.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
for t in range(30000):
    # Forward pass: Compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = criterion(y_pred, y)
    if t % 2000 == 1999:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f'Result: {model.string()}')
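One point worth spelling out: because self.e is reused for both the x^4 and x^5 terms, autograd accumulates its gradient across every place the parameter appears in the graph. Below is a minimal sketch (not part of the tutorial code; the values are made up for illustration) showing that a parameter used twice in one graph receives the sum of the gradients from both uses:

import torch

# The same parameter appears twice in a single computation graph.
w = torch.nn.Parameter(torch.tensor(1.0))
out = 2 * w + 3 * w
out.backward()
print(w.grad)  # tensor(5.) -- d(2w)/dw + d(3w)/dw = 2 + 3

This is exactly why the forward pass above can safely add self.e * x ** exp a variable number of times: each use simply contributes one more term to the gradient of self.e.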
