PyTorch
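A summary of the ctx parameter that appears in PyTorch model source code:

• ctx is short for "context".
• It is used mainly when defining a custom torch.autograd.Function.
• ctx appears in static methods: forward() and backward() are both decorated with @staticmethod.
• self refers to an instance object. These static methods are never called on an instance. The Function is used through its class (e.g. SpecialSpmmFunction.apply(...)), so self has no meaning there. Instead, autograd creates a context object and passes it in as ctx.
• The first parameter of the custom forward() and backward() methods must be ctx. (PyTorch 2.0 and later also allow forward() without ctx when a separate setup_context() static method is defined.) ctx can hold values from forward() so that backward() can use them later. The next point gives a concrete example.
• **ctx.save_for_backward(a, b)** saves tensors in the **forward()** static method so they can be used in the backward() static method. In the code below, **a, b = ctx.saved_tensors** retrieves a and b again. Values that are not tensors, such as ctx.N = shape[0] below, are stored as ordinary attributes on ctx instead.
• **ctx.needs_input_grad** is a tuple of True/False values, one per input of **forward()**, telling whether that input needs a gradient. For example, ctx.needs_input_grad[0] says whether indices in the forward() below needs a gradient.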
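The sketch below shows the points above on a small example. ScaledMul is a hypothetical function and does not come from the original post. It computes y = x * w * scale, with scale as a plain Python float. Tensors go through save_for_backward(), the float is stored as a ctx attribute, and backward() checks ctx.needs_input_grad before computing each gradient. torch.autograd.gradcheck compares the hand-written backward() against numerical gradients.

```python
import torch


class ScaledMul(torch.autograd.Function):
    """Hypothetical example: y = x * w * scale, with scale a plain Python float."""

    @staticmethod
    def forward(ctx, x, w, scale):
        ctx.save_for_backward(x, w)  # tensors needed in backward()
        ctx.scale = scale            # non-tensor state stored as a ctx attribute
        return x * w * scale

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        grad_x = grad_w = None
        # Only compute gradients for inputs that actually need them.
        if ctx.needs_input_grad[0]:
            grad_x = grad_output * w * ctx.scale
        if ctx.needs_input_grad[1]:
            grad_w = grad_output * x * ctx.scale
        # One return value per forward() input; scale is not a tensor, so None.
        return grad_x, grad_w, None


x = torch.randn(5, dtype=torch.double, requires_grad=True)
w = torch.randn(5, dtype=torch.double, requires_grad=True)

# The Function is used through the class via .apply(); no instance is created.
ok = torch.autograd.gradcheck(lambda a, b: ScaledMul.apply(a, b, 3.0), (x, w))
print(ok)  # True

# With a detached w, ctx.needs_input_grad is (True, False, False),
# so backward() skips the grad_w branch.
x.grad = None  # start from a clean gradient
y = ScaledMul.apply(x, w.detach(), 2.0)
y.sum().backward()
print(torch.allclose(x.grad, w.detach() * 2.0))  # True
```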

```python
import torch


class SpecialSpmmFunction(torch.autograd.Function):
    """
    Special function for only sparse region backpropagation layer.
    """

    # Custom forward pass
    @staticmethod
    def forward(ctx, indices, values, shape, b):
        assert not indices.requires_grad
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)  # tensors needed in backward()
        ctx.N = shape[0]             # non-tensor state stored as a ctx attribute
        return torch.matmul(a, b)

    # Custom backward pass
    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:  # gradient w.r.t. values
            grad_a_dense = grad_output.matmul(b.t())
            # Flat index row * N + col; this assumes a is square (N x N).
            edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
            grad_values = grad_a_dense.view(-1)[edge_idx]
        if ctx.needs_input_grad[3]:  # gradient w.r.t. b
            grad_b = a.t().matmul(grad_output)
        # One gradient per forward() input: indices and shape get None.
        return None, grad_values, None, grad_b
```
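You can sanity-check the backward() above by comparing it with ordinary dense autograd. The sketch below repeats the class so that the block runs on its own. The sparse pattern, the sizes, and the weights tensor are made-up test data. The matrix is square (N x N), because the ctx.N flattening relies on that. The reference path builds the same matrix densely with index_put and lets autograd differentiate it.

```python
import torch


class SpecialSpmmFunction(torch.autograd.Function):
    """Same logic as the class above, repeated so this block runs on its own."""

    @staticmethod
    def forward(ctx, indices, values, shape, b):
        assert not indices.requires_grad
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)
        ctx.N = shape[0]
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:
            grad_a_dense = grad_output.matmul(b.t())
            edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
            grad_values = grad_a_dense.view(-1)[edge_idx]
        if ctx.needs_input_grad[3]:
            grad_b = a.t().matmul(grad_output)
        return None, grad_values, None, grad_b


torch.manual_seed(0)
N, F = 4, 3
# Hypothetical sparse pattern with 5 distinct (row, col) entries.
indices = torch.tensor([[0, 1, 2, 3, 3], [1, 0, 3, 2, 0]])
values = torch.rand(5, dtype=torch.double, requires_grad=True)
b = torch.rand(N, F, dtype=torch.double, requires_grad=True)
weights = torch.rand(N, F, dtype=torch.double)  # makes grad_output non-trivial

# Custom Function path.
out = SpecialSpmmFunction.apply(indices, values, (N, N), b)
(out * weights).sum().backward()
grad_values_custom, grad_b_custom = values.grad.clone(), b.grad.clone()

# Reference path: build the same matrix densely and let autograd differentiate it.
values.grad, b.grad = None, None
dense_a = torch.zeros(N, N, dtype=torch.double).index_put((indices[0], indices[1]), values)
ref = dense_a @ b
(ref * weights).sum().backward()

print(torch.allclose(out, ref))                         # True
print(torch.allclose(grad_values_custom, values.grad))  # True
print(torch.allclose(grad_b_custom, b.grad))            # True
```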

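ctx also provides many other methods and attributes. The ctx object is an instance of a class that PyTorch derives from torch._C._FunctionBase. The IDE-generated stub of that base class is shown below; the "real signature unknown" comments come from the stub generator. Several of these attributes are set by ctx methods:

• save_for_backward() sets to_save.
• mark_dirty() sets dirty_tensors.
• mark_non_differentiable() sets non_differentiable.

saved_variables is a deprecated alias of saved_tensors.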

```python
class _FunctionBase(object):
    # no doc
    @classmethod
    def apply(cls, *args, **kwargs): # real signature unknown
        pass

    def register_hook(self, *args, **kwargs): # real signature unknown
        pass

    def _do_backward(self, *args, **kwargs): # real signature unknown
        pass

    def _do_forward(self, *args, **kwargs): # real signature unknown
        pass

    def _register_hook_dict(self, *args, **kwargs): # real signature unknown
        pass

    def __init__(self, *args, **kwargs): # real signature unknown
        pass

    @staticmethod # known case of __new__
    def __new__(*args, **kwargs): # real signature unknown
        """ Create and return a new object. See help(type) for accurate signature. """
        pass

    dirty_tensors = property(lambda self: object(), lambda self, v: None, lambda self: None) # default
    metadata = property(lambda self: object(), lambda self, v: None, lambda self: None) # default
    needs_input_grad = property(lambda self: object(), lambda self, v: None, lambda self: None) # default
    next_functions = property(lambda self: object(), lambda self, v: None, lambda self: None) # default
    non_differentiable = property(lambda self: object(), lambda self, v: None, lambda self: None) # default
    requires_grad = property(lambda self: object(), lambda self, v: None, lambda self: None) # default
    saved_tensors = property(lambda self: object(), lambda self, v: None, lambda self: None) # default
    saved_variables = property(lambda self: object(), lambda self, v: None, lambda self: None) # default
    to_save = property(lambda self: object(), lambda self, v: None, lambda self: None) # default
```
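The sketch below shows one of those other ctx methods. SortWithIndices is a hypothetical function, written along the lines of the example in the PyTorch documentation for mark_non_differentiable. It returns both the sorted values and the integer sort indices. It marks the indices as non-differentiable, which populates the non_differentiable attribute from the stub above. backward() still receives one gradient per forward() output, but it ignores the one for the indices.

```python
import torch


class SortWithIndices(torch.autograd.Function):
    """Hypothetical example: sort a 1-D tensor and also return the sort indices."""

    @staticmethod
    def forward(ctx, x):
        values, idx = x.sort()
        ctx.mark_non_differentiable(idx)  # integer indices get no gradient
        ctx.save_for_backward(idx)
        return values, idx

    @staticmethod
    def backward(ctx, grad_values, grad_idx):
        # One incoming gradient per forward() output; grad_idx is ignored.
        (idx,) = ctx.saved_tensors
        grad_x = torch.zeros_like(grad_values)
        # values[i] = x[idx[i]], so route each gradient back to its original position.
        grad_x.index_add_(0, idx, grad_values)
        return grad_x


x = torch.tensor([3.0, 1.0, 2.0], requires_grad=True)
values, idx = SortWithIndices.apply(x)
print(values.requires_grad, idx.requires_grad)  # True False
(values * torch.tensor([1.0, 10.0, 100.0])).sum().backward()
print(x.grad.tolist())  # [100.0, 1.0, 10.0]
```

For in-place operations the analogous call is ctx.mark_dirty(), which fills dirty_tensors.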