This article references an earlier analysis.

    That write-up is excellent, but it is based on the source code of PyTorch 0.4 and earlier; this article walks through the PyTorch 1.7 source instead, and adds UML diagrams to make the structure easier to follow.

    [torch/autograd/function.py]
    class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)):  # type: ignore
        """Records operation history and defines formulas for differentiating ops."""

        @staticmethod
        def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
            raise NotImplementedError("You must implement the forward function for custom"
                                      " autograd.Function.")

        @staticmethod
        def backward(ctx: Any, *grad_outputs: Any) -> Any:
            raise NotImplementedError("You must implement the backward function for custom"
                                      " autograd.Function.")

    with_metaclass() exists to paper over the metaclass syntax differences between Python 2 and Python 3 (see "Understanding with_metaclass()" and "How does it work - with_metaclass" for details); it is defined in torch/_six.py.
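    For reference, here is a sketch of how with_metaclass() works. It mirrors the well-known implementation from the six package; torch/_six.py uses essentially the same trick, so treat this as an illustration rather than a verbatim copy of the PyTorch source.

    def with_metaclass(meta, *bases):
        # A throwaway metaclass used for exactly one class creation: when the
        # `class Function(with_metaclass(...))` statement runs, it replaces
        # itself with the real metaclass `meta` and the real `bases`.
        class metaclass(meta):
            def __new__(cls, name, this_bases, d):
                return meta(name, bases, d)
        return type.__new__(metaclass, 'temporary_class', (), {})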
    The definition above is therefore equivalent to the following in Python 3:

    class Function(_C._FunctionBase, _ContextMethodMixin, _HookMixin, metaclass=FunctionMeta):
    
    [torch/autograd/function.py]
    class BackwardCFunction(_C._FunctionBase, _ContextMethodMixin, _HookMixin):
        _is_legacy = False
    
        def apply(self, *args):
            # _forward_cls is defined by derived class
            return self._forward_cls.backward(self, *args)  # type: ignore
    
    class FunctionMeta(type):
        def __init__(cls, name, bases, attrs):
            ...
            backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls})
            cls._backward_cls = backward_fn
    
            return super(FunctionMeta, cls).__init__(name, bases, attrs)
    

    FunctionMeta inherits from type and serves as the metaclass of Function; a metaclass's job is to create classes. FunctionMeta adds an attribute _backward_cls to the class being created, whose value is backward_fn. backward_fn is itself a class generated dynamically with type(): it inherits from BackwardCFunction and carries an attribute _forward_cls whose value is cls, i.e. the class that FunctionMeta is creating.
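    Since the three-argument form of type() may be less familiar, here is a tiny standalone illustration of the same pattern (the names Base and FooBackward, and the use of int as _forward_cls, are made up for the example; the real call is the type(name + 'Backward', ...) line in FunctionMeta above):

    class Base:
        def apply(self):
            return self._forward_cls.__name__

    FooBackward = type('FooBackward', (Base,), {'_forward_cls': int})
    print(FooBackward.__name__)    # FooBackward
    print(FooBackward().apply())   # int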

    class Exp(Function):
    
        @staticmethod
        def forward(ctx, i):
            result = i.exp()
            ctx.save_for_backward(result)
            return result
    
        @staticmethod
        def backward(ctx, grad_output):
            result, = ctx.saved_tensors
            return grad_output * result
    
    >>> Exp._backward_cls
    torch.autograd.function.ExpBackward
    >>> Exp._backward_cls()
    <torch.autograd.function.ExpBackward at 0x7fe9a4519e60>
    >>> Exp._backward_cls._forward_cls
    __main__.Exp
    >>> Exp._backward_cls._forward_cls.apply?
    Docstring: <no docstring>
    Type:      builtin_function_or_method
    >>> Exp._backward_cls.apply?
    Signature: Exp._backward_cls.apply(self, *args)
    Docstring: <no docstring>
    File:      ~/miniconda3/envs/jupyter/lib/python3.7/site-packages/torch/autograd/function.py
    Type:      function
    

    When FunctionMeta creates the Exp class, it also creates an ExpBackward class. As the transcript shows, both Exp and ExpBackward inherit from _C._FunctionBase, _ContextMethodMixin, and _HookMixin. In addition, ExpBackward inherits the apply method from BackwardCFunction, and that apply method calls Exp's backward method.

    At this point we know how ExpBackward ends up calling Exp's backward method. How, then, does Exp's forward method get called? The answer is the apply method of _C._FunctionBase.
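    Before diving into the C++ side, both directions can be checked end to end from Python (a minimal example; it reuses the Exp class defined above):

    import torch    # Exp is the class defined above

    x = torch.tensor([0.0, 1.0], requires_grad=True)
    y = Exp.apply(x)            # _C._FunctionBase.apply invokes Exp.forward
    print(type(y.grad_fn))      # <class 'torch.autograd.function.ExpBackward'>
    y.sum().backward()          # the engine calls ExpBackward.apply -> Exp.backward
    print(torch.allclose(x.grad, x.exp()))   # True, since d/dx exp(x) = exp(x)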

    [torch/csrc/autograd/python_function.cpp]
    PyTypeObject THPFunctionType = {
      PyVarObject_HEAD_INIT(nullptr, 0)
      "torch._C._FunctionBase",                    /* tp_name */
      ...
      THPFunction_methods,                         /* tp_methods */
      nullptr,                                     /* tp_members */
      THPFunction_properties,                      /* tp_getset */
      ...
      THPFunction_new                              /* tp_new */
    };
    

    Here we only look at THPFunction_methods.

    [torch/csrc/autograd/python_function.cpp]
    static struct PyMethodDef THPFunction_methods[] = {
      ...
      {(char*)"apply", THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr},
      ...
    };
    

    THPFunction_methods registers an apply entry that points to THPFunction_apply. Note the METH_CLASS flag: apply receives the class itself (e.g. Exp) as its first argument.
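    A quick way to confirm from Python that this binding lives on the C base class rather than on Exp itself (assumes import torch and the Exp class from above; output reprs are abbreviated and may vary by version):

    print('apply' in vars(Exp))                      # False: not defined on Exp itself
    print('apply' in vars(torch._C._FunctionBase))   # True: the C-level method from THPFunction_methods
    print(Exp.apply)                                 # a built-in method bound to the Exp class (METH_CLASS)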

    [torch/csrc/autograd/python_function.cpp]
    PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
    {
      HANDLE_TH_ERRORS
      RECORD_FUNCTION(
        ((PyTypeObject*)cls)->tp_name,
        std::vector<c10::IValue>(),
        at::sequence_number::peek());
    
      THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
      if (!backward_cls) return nullptr;
      THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr));
      if (!ctx_obj) return nullptr;
      THPFunction* ctx = (THPFunction*)ctx_obj.get();
    
      auto cdata = std::shared_ptr<PyNode>(new PyNode(std::move(ctx_obj)), deleteNode);
      ctx->cdata = cdata;
    
      // Prepare inputs and allocate context (grad fn)
      auto info_pair = unpack_input<false>(inputs);
      UnpackedInput& unpacked_input = info_pair.first;
      InputFlags& input_info = info_pair.second;
    
      // Record input nodes if tracing
      auto* node = _trace_pre_record(cls, inputs, unpacked_input.input_vars);
    
      // Initialize backward function (and ctx)
      bool is_executable = input_info.is_executable;
      cdata->set_next_edges(std::move(input_info.next_edges));
      ctx->needs_input_grad = input_info.needs_input_grad.release();
      ctx->is_variable_input = std::move(input_info.is_variable_input);
    
      // Prepend ctx to input_tuple, in preparation for static method call
      auto num_args = PyTuple_GET_SIZE(inputs);
      THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1));
      if (!ctx_input_tuple) return nullptr;
      Py_INCREF(ctx);
      PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, (PyObject*)ctx);
      for (int i = 0; i < num_args; ++i) {
        PyObject *arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i);
        Py_INCREF(arg);
        PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg);
      }
    
      // Call forward
      THPObjectPtr tensor_outputs;
      {
        AutoGradMode grad_mode(false);
        THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
        if (!forward_fn) return nullptr;
        tensor_outputs = PyObject_CallObject(forward_fn, ctx_input_tuple);
        if (!tensor_outputs) return nullptr;
      }
    
      return process_outputs(cls, cdata, ctx, unpacked_input, inputs, std::move(tensor_outputs),
                             is_executable, node);
      END_HANDLE_TH_ERRORS
    }
    

    THPFunction_apply breaks down logically into five blocks (a small runnable probe follows the list):

    1. Prepare inputs and allocate context (grad fn)

      The inputs argument is parsed by unpack_input, yielding two objects: unpacked_input and input_info.

    2. Record input nodes if tracing

      If the JIT tracer is active, _trace_pre_record records the input nodes (the result is consumed later by _trace_post_record).

    3. Initialize backward function (and ctx)

      Initialize the backward function (and ctx; ctx holds the information that backward will need).

    4. Prepend ctx to input_tuple, in preparation for static method call

      Repack the arguments: ctx is placed at the front of ctx_input_tuple, and the entries of unpacked_input.input_tuple are copied after it in order.

    5. Call forward

      forward is invoked with grad mode disabled (AutoGradMode grad_mode(false)), so the operations inside forward itself are not recorded.
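    A small runnable probe makes two of these blocks observable from Python (the class name Probe is just for illustration): block 5 runs forward with grad mode turned off, yet the output still ends up with a grad_fn, because the edges had already been collected from the inputs (blocks 1 and 3) and are attached to the outputs afterwards in process_outputs.

    import torch
    from torch.autograd import Function

    class Probe(Function):
        @staticmethod
        def forward(ctx, x):
            # AutoGradMode grad_mode(false) is in effect while forward runs
            print(torch.is_grad_enabled())            # False
            return x * 2

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output * 2

    x = torch.ones(3, requires_grad=True)
    y = Probe.apply(x)
    print(y.requires_grad, type(y.grad_fn).__name__)  # True ProbeBackward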

    We know that PyTorch builds its computation graph during the forward pass, so which part of THPFunction_apply is the graph-building code?

    [torch/csrc/autograd/python_function.cpp]
    struct UnpackedInput {
      THPObjectPtr input_tuple;
      variable_list input_vars;
    };
    
    struct InputFlags {
      bool is_executable = false;
      edge_list next_edges;
      THPObjectPtr needs_input_grad;
      std::vector<bool> is_variable_input;
    };
    

    UnpackedInput.input_tuple holds the parsed inputs, while InputFlags is filled in by walking the components of inputs one by one to determine the gradient-related flags of each argument.

    [torch/csrc/autograd/python_function.cpp]
    template<bool enforce_variables>
    std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) {
      UnpackedInput unpacked;
      InputFlags flags;
    
      auto num_args = PyTuple_GET_SIZE(args);
      unpacked.input_tuple = PyTuple_New(num_args);
      flags.needs_input_grad = PyTuple_New(num_args);
      for (int i = 0; i < num_args; i++) {
        PyObject *arg = PyTuple_GET_ITEM(args, i);
    
        bool is_variable = THPVariable_Check(arg);
        flags.is_variable_input.push_back(is_variable);
        if (!is_variable) {
          // TODO: remove this code path once Variable and Tensor are merged in Python
          if (enforce_variables) {
            THPUtils_setError("expected a Variable argument, but got %s",
                              THPUtils_typename(arg));
            throw python_error();
          }
          Py_INCREF(Py_False);
          PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, Py_False);
        } else {
          THPVariable* variable = (THPVariable*)arg;
          unpacked.input_vars.push_back(variable->cdata);
          PyObject* needs_grad = variable->cdata.requires_grad() ? Py_True : Py_False;
          Py_INCREF(needs_grad);
          PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad);
        }
        Py_INCREF(arg);
        PyTuple_SET_ITEM(unpacked.input_tuple.get(), i, arg);
      }
    
      flags.is_executable = GradMode::is_enabled() && any_variable_requires_grad(unpacked.input_vars);
      flags.next_edges = collect_next_edges(unpacked.input_vars);
      return std::make_pair(std::move(unpacked), std::move(flags));
    }
    

    InputFlags.needs_input_grad records, for each argument, whether it requires a gradient, and InputFlags.is_variable_input records whether the argument is a Variable (Tensor) at all.
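    The effect is visible from Python: needs_input_grad ends up on ctx, so a custom Function can skip gradients for inputs that do not need them, and is_executable shows up as whether a grad_fn gets attached at all (a small sketch; the class name Mul is illustrative):

    import torch
    from torch.autograd import Function

    class Mul(Function):
        @staticmethod
        def forward(ctx, a, b):
            # ctx.needs_input_grad was filled from InputFlags.needs_input_grad
            ctx.save_for_backward(a, b)
            return a * b

        @staticmethod
        def backward(ctx, grad_output):
            a, b = ctx.saved_tensors
            grad_a = grad_output * b if ctx.needs_input_grad[0] else None
            grad_b = grad_output * a if ctx.needs_input_grad[1] else None
            return grad_a, grad_b

    a = torch.tensor(2.0, requires_grad=True)
    b = torch.tensor(3.0)                    # requires_grad=False -> needs_input_grad[1] is False
    Mul.apply(a, b).backward()
    print(a.grad)                            # tensor(3.), i.e. d(a*b)/da = b

    with torch.no_grad():                    # grad mode off -> is_executable is false
        print(Mul.apply(a, b).grad_fn)       # None: no graph is built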

    The line flags.next_edges = collect_next_edges(unpacked.input_vars); is the key to building the computation graph (that is, the key to chaining derivatives during backpropagation).

    [torch/csrc/autograd/function.h]
    /// Return the next edges of all the given variables, or tuples of variables.
    template <typename... Variables>
    edge_list collect_next_edges(Variables&&... variables) {
      if (!GradMode::is_enabled())
        return {};
      detail::MakeNextFunctionList make;
      make.apply(std::forward<Variables>(variables)...);
      return std::move(make.next_edges);
    }
    

    collect_next_edges delegates the per-variable work to MakeNextFunctionList:

    [torch/csrc/autograd/function.h]
    struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
      edge_list next_edges;
      using IterArgs<MakeNextFunctionList>::operator();
      void operator()(const Variable& variable) {
        if (variable.defined()) {
          next_edges.push_back(impl::gradient_edge(variable));
        } else {
          next_edges.emplace_back();
        }
      }
      void operator()(const c10::optional<Variable>& variable) {
        if (variable.has_value() && variable->defined()) {
          next_edges.push_back(impl::gradient_edge(*variable));
        } else {
          next_edges.emplace_back();
        }
      }
    };
    

    Each edge is produced by gradient_edge:

    [torch/csrc/autograd/variable.cpp]
    Edge gradient_edge(const Variable& self) {
        // If grad_fn is null (as is the case for a leaf node), we instead
        // interpret the gradient function to be a gradient accumulator, which will
        // accumulate its inputs into the grad property of the variable. These
        // nodes get suppressed in some situations, see "suppress gradient
        // accumulation" below. Note that only variables which have `requires_grad =
        // True` can have gradient accumulators.
        if (const auto& gradient = self.grad_fn()) {
          return Edge(gradient, self.output_nr());
        } else {
          return Edge(grad_accumulator(self), 0);
        }
      }
    

    At this point the edges to connect have been found: for a non-leaf variable the edge points to its grad_fn (together with its output index), and for a leaf variable that requires grad it points to a gradient accumulator (AccumulateGrad) that accumulates into .grad.
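    From Python these edges surface as grad_fn.next_functions (a quick check with built-in ops; the exact backward class names, e.g. MulBackward0 vs. MulBackward, vary slightly across versions):

    import torch

    a = torch.tensor(1.0, requires_grad=True)   # leaf that requires grad
    b = torch.tensor(2.0)                        # leaf that does not
    c = (a * b).exp()

    print(c.grad_fn)                             # ExpBackward
    print(c.grad_fn.next_functions)              # ((MulBackward0, 0),) -> edge to the producer of a*b
    print(c.grad_fn.next_functions[0][0].next_functions)
    # ((AccumulateGrad, 0), (None, 0)) -> gradient accumulator for leaf `a`, empty edge for `b`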

    Finally, process_outputs wraps the raw outputs of forward into Variables whose grad_fn is cdata (_wrap_outputs), records type/device/size information about the inputs, and saves the tensors requested via ctx.save_for_backward (_save_variables):

    [torch/csrc/autograd/python_function.cpp]
    PyObject* process_outputs(PyObject *op_obj, const std::shared_ptr<PyNode>& cdata,
                              THPFunction* grad_fn, const UnpackedInput& unpacked,
                              PyObject *inputs, THPObjectPtr&& raw_output, bool is_executable,
                              torch::jit::Node* node) {
      bool unpack_output = ensure_tuple(raw_output);
    
      auto num_outputs = PyTuple_GET_SIZE(raw_output.get());
    
      THPObjectPtr outputs(PyTuple_New(num_outputs));
      if (!outputs) throw python_error();
    
      cdata->clear_input_metadata();
    
      // Record type, device, and size information about inputs
      if (is_executable) {
        grad_fn->input_info.clear();
        grad_fn->input_info.reserve(unpacked.input_vars.size());
        for (auto& var : unpacked.input_vars) {
          grad_fn->input_info.emplace_back(var);
        }
      }
    
      bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
      _wrap_outputs(cdata, grad_fn, unpacked.input_vars, raw_output, outputs, is_executable);
      _trace_post_record(node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output);
      if (is_executable) {
        _save_variables(cdata, grad_fn);
      } else {
        // Remove unnecessary attributes
        Py_XDECREF(grad_fn->to_save);
        grad_fn->to_save = nullptr;
        Py_XDECREF(grad_fn->non_differentiable);
        grad_fn->non_differentiable = nullptr;
      }
    
      // Unpack the output, unless .forward() returned a tuple
      if (unpack_output) {
        PyObject *output = PyTuple_GET_ITEM(outputs.get(), 0);
        Py_INCREF(output);
        return output;
      }
    
      return outputs.release();
    }
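
    The final unpack_output branch is easy to observe as well: if forward returns a tuple, apply returns the tuple as-is; otherwise the single output is unpacked (a small sketch; the class name Split is illustrative):

    import torch
    from torch.autograd import Function

    class Split(Function):
        @staticmethod
        def forward(ctx, x):
            return x * 2, x * 3          # a tuple, so unpack_output will be False

        @staticmethod
        def backward(ctx, g2, g3):
            return g2 * 2 + g3 * 3       # one gradient per input of forward

    x = torch.tensor(1.0, requires_grad=True)
    a, b = Split.apply(x)                # apply returns the tuple unchanged
    (a + b).backward()
    print(x.grad)                        # tensor(5.)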