PyTorch: 定义新的autograd函数

译者:@yongjay13@speedmancs

校对者:@bringtree

本例中的全连接神经网络有一个隐藏层, 后接ReLU激活层, 并且不带偏置参数. 训练时通过最小化欧式距离的平方, 来学习从x到y的映射.

在此实现中, 我们使用PyTorch变量上的函数来进行前向计算, 然后用PyTorch的autograd计算梯度

我们还实现了一个定制化的autograd函数, 用于ReLU函数.

  1. import torch
  2. from torch.autograd import Variable
  3. class MyReLU(torch.autograd.Function):
  4. """
  5. 我们可以通过子类实现我们自己定制的autograd函数
  6. torch.autograd.Function和执行在Tensors上运行的向前和向后通行证.
  7. """
  8. @staticmethod
  9. def forward(ctx, input):
  10. """
  11. 在正向传递中,我们收到一个包含输入和返回张量的张量,其中包含输出.
  12. ctx是一个上下文对象,可用于存储反向计算的信息.
  13. 您可以使用ctx.save_for_backward方法缓存任意对象以用于后向传递.
  14. """
  15. ctx.save_for_backward(input)
  16. return input.clamp(min=0)
  17. @staticmethod
  18. def backward(ctx, grad_output):
  19. """
  20. 在后向传递中,我们收到一个张量,其中包含相对于输出的损失梯度,
  21. 我们需要计算相对于输入的损失梯度.
  22. """
  23. input, = ctx.saved_tensors
  24. grad_input = grad_output.clone()
  25. grad_input[input < 0] = 0
  26. return grad_input
  27. dtype = torch.FloatTensor
  28. # dtype = torch.cuda.FloatTensor # 取消注释以在GPU上运行
  29. # N 批量大小; D_in是输入尺寸;
  30. # H是隐藏尺寸; D_out是输出尺寸.
  31. N, D_in, H, D_out = 64, 1000, 100, 10
  32. # 创建随机张量来保存输入和输出,并将它们包装在变量中.
  33. x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False)
  34. y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False)
  35. # 为权重创建随机张量,并将其包装在变量中.
  36. w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True)
  37. w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True)
  38. learning_rate = 1e-6
  39. for t in range(500):
  40. # 为了应用我们的函数,我们使用Function.apply方法.我们把它称为'relu'.
  41. relu = MyReLU.apply
  42. # 正向传递:使用变量上的运算来计算预测的y;
  43. # 我们使用我们的自定义autograd操作来计算ReLU.
  44. y_pred = relu(x.mm(w1)).mm(w2)
  45. # 计算和打印损失
  46. loss = (y_pred - y).pow(2).sum()
  47. print(t, loss.data[0])
  48. # 使用autograd来计算反向传递.
  49. loss.backward()
  50. # 使用梯度下降更新权重
  51. w1.data -= learning_rate * w1.grad.data
  52. w2.data -= learning_rate * w2.grad.data
  53. # 更新权重后手动将梯度归零
  54. w1.grad.data.zero_()
  55. w2.grad.data.zero_()