使用 numpy 和 scipy 创建扩展

译者:@飞龙

作者: Adam Paszke

这个教程中, 我们将完成以下两个任务:

  1. 创建不带参数的神经网络层

    > 这会调用 *numpy, 作为其实现的一部分

  2. 创建带有可学习的权重的神经网络层

    > 这会调用 *SciPy, 作为其实现的一部分

  1. import torch
  2. from torch.autograd import Function
  3. from torch.autograd import Variable

无参示例

这一层并不做任何有用的, 或者数学上正确的事情.

它被恰当地命名为 BadFFTFunction

层的实现

  1. from numpy.fft import rfft2, irfft2
  2. class BadFFTFunction(Function):
  3. def forward(self, input):
  4. numpy_input = input.numpy()
  5. result = abs(rfft2(numpy_input))
  6. return torch.FloatTensor(result)
  7. def backward(self, grad_output):
  8. numpy_go = grad_output.numpy()
  9. result = irfft2(numpy_go)
  10. return torch.FloatTensor(result)
  11. # 由于这一层没有任何参数, 我们可以
  12. # 仅仅将其声明为一个函数, 而不是 nn.Module 类
  13. def incorrect_fft(input):
  14. return BadFFTFunction()(input)

所创建的层的使用示例:

  1. input = Variable(torch.randn(8, 8), requires_grad=True)
  2. result = incorrect_fft(input)
  3. print(result.data)
  4. result.backward(torch.randn(result.size()))
  5. print(input.grad)

参数化示例

它实现了带有可学习的权重的层.

它使用可学习的核, 实现了互相关.

在深度学习文献中, 它容易和卷积混淆.

反向过程计算了输入和滤波的梯度.

实现:

要注意, 实现作为一个演示, 我们并不验证它的正确性

  1. from scipy.signal import convolve2d, correlate2d
  2. from torch.nn.modules.module import Module
  3. from torch.nn.parameter import Parameter
  4. class ScipyConv2dFunction(Function):
  5. @staticmethod
  6. def forward(ctx, input, filter):
  7. result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
  8. ctx.save_for_backward(input, filter)
  9. return torch.FloatTensor(result)
  10. @staticmethod
  11. def backward(ctx, grad_output):
  12. input, filter = ctx.saved_tensors
  13. grad_output = grad_output.data
  14. grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')
  15. grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')
  16. return Variable(torch.FloatTensor(grad_input)), \
  17. Variable(torch.FloatTensor(grad_filter))
  18. class ScipyConv2d(Module):
  19. def __init__(self, kh, kw):
  20. super(ScipyConv2d, self).__init__()
  21. self.filter = Parameter(torch.randn(kh, kw))
  22. def forward(self, input):
  23. return ScipyConv2dFunction.apply(input, self.filter)

示例用法:

  1. module = ScipyConv2d(3, 3)
  2. print(list(module.parameters()))
  3. input = Variable(torch.randn(10, 10), requires_grad=True)
  4. output = module(input)
  5. print(output)
  6. output.backward(torch.randn(8, 8))
  7. print(input.grad)