PyTorch
Anyone who has used PyTorch is familiar with `backward`: it runs back-propagation to compute gradients. In the example below, gradients have to be computed via back-propagation before the optimizer's **step** function can be called to update the model parameters.

```python
Example:
    >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    >>> optimizer.zero_grad()
    >>> loss_fn(model(input), target).backward()
    >>> optimizer.step()
```

[1] torch.Tensor.backward

In torch/tensor.py, **class Tensor(torch._C._TensorBase)** defines a **def backward** method, so **tensor.backward()** can be called to run back-propagation.

```python
# https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html
def backward(self, gradient=None, retain_graph=None, create_graph=False):
    r"""Computes the gradient of current tensor w.r.t. graph leaves.

    The graph is differentiated using the chain rule. If the tensor is
    non-scalar (i.e. its data has more than one element) and requires
    gradient, the function additionally requires specifying ``gradient``.
    It should be a tensor of matching type and location, that contains
    the gradient of the differentiated function w.r.t. ``self``.

    This function accumulates gradients in the leaves - you might need to
    zero them before calling it.

    Arguments:
        gradient (Tensor or None): Gradient w.r.t. the
            tensor. If it is a tensor, it will be automatically converted
            to a Tensor that does not require grad unless ``create_graph`` is True.
            None values can be specified for scalar Tensors or ones that
            don't require grad. If a None value would be acceptable then
            this argument is optional.
        retain_graph (bool, optional): If ``False``, the graph used to compute
            the grads will be freed. Note that in nearly all cases setting
            this option to True is not needed and often can be worked around
            in a much more efficient way. Defaults to the value of
            ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will
            be constructed, allowing to compute higher order derivative
            products. Defaults to ``False``.
    """
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
```

The **create_graph** parameter, if True, constructs a graph of the derivative itself, which makes it possible to compute higher-order derivatives. The **retain_graph** parameter can usually be ignored, since it is rarely needed in practice; it controls whether the graph is kept after the backward pass. The implementation itself is trivial: it simply calls **torch.autograd.backward**, so that is what we look at next.
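As a concrete illustration of **create_graph**, here is a minimal sketch of computing a second derivative (the function x ** 3 is just an example chosen for this note, not taken from the PyTorch sources):

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3                      # y = x^3
y.backward(create_graph=True)   # also build the graph of the derivative
print(x.grad)                   # dy/dx = 3*x^2 = 12.0

# Because create_graph=True, x.grad itself carries a grad_fn and can be
# differentiated again to get the second derivative d^2y/dx^2 = 6*x.
(second,) = torch.autograd.grad(x.grad, x)
print(second)                   # 12.0
```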

[2] torch.autograd.backward

The function **torch.autograd.backward** is defined in torch/autograd/__init__.py. With the chain rule and the Jacobian-vector product, gradients can be computed conveniently. The code is shown below.

```python
# https://github.com/pytorch/pytorch/blob/master/torch/autograd/__init__.py
# https://pytorch.org/docs/stable/generated/torch.autograd.backward.html
# ...
from .variable import Variable
# ...

def _make_grads(outputs, grads):
    new_grads = []
    for out, grad in zip(outputs, grads):
        if isinstance(grad, torch.Tensor):
            if not out.shape == grad.shape:
                raise RuntimeError("...")  # shape mismatch between output and its gradient
            new_grads.append(grad)
        elif grad is None:
            if out.requires_grad:
                if out.numel() != 1:
                    raise RuntimeError("...")  # non-scalar outputs need an explicit gradient
                new_grads.append(torch.ones_like(out))  # implicit gradient for a scalar output
            else:
                new_grads.append(None)
        else:
            raise TypeError("...")  # gradients must be Tensors or None
    return tuple(new_grads)


def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None):
    r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.

    The graph is differentiated using the chain rule. If any of ``tensors``
    are non-scalar (i.e. their data has more than one element) and require
    gradient, then the Jacobian-vector product would be computed, in this
    case the function additionally requires specifying ``grad_tensors``.
    It should be a sequence of matching length, that contains the "vector"
    in the Jacobian-vector product, usually the gradient of the differentiated
    function w.r.t. corresponding tensors (``None`` is an acceptable value for
    all tensors that don't need gradient tensors).

    This function accumulates gradients in the leaves - you might need to zero
    them before calling it.
    """
    if grad_variables is not None:
        warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
        if grad_tensors is None:
            grad_tensors = grad_variables
        else:
            raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
                               "arguments both passed to backward(). Please only "
                               "use 'grad_tensors'.")

    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)

    if grad_tensors is None:
        grad_tensors = [None] * len(tensors)
    elif isinstance(grad_tensors, torch.Tensor):
        grad_tensors = [grad_tensors]
    else:
        grad_tensors = list(grad_tensors)

    grad_tensors = _make_grads(tensors, grad_tensors)
    if retain_graph is None:
        retain_graph = create_graph

    Variable._execution_engine.run_backward(
        tensors, grad_tensors, retain_graph, create_graph,
        allow_unreachable=True)  # allow_unreachable flag

# ...
if not torch._C._autograd_init():
    raise RuntimeError("autograd initialization failed")
```

The **grad_variables** argument belongs to the old API and has been **deprecated**; **grad_tensors** is used instead. Passing it still works: the code simply forwards the value of grad_variables to **grad_tensors**. The helper **_make_grads** checks every element of grad_tensors and normalizes **grad_tensors** into a tuple of Tensors (or None values). After this preprocessing, **Variable._execution_engine.run_backward** is called with the checked and reorganized arguments. Note the **allow_unreachable** argument; it will come up again below.
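To make the docstring above concrete, here is a small sketch (the tensors and the vector v are made up for illustration) of supplying **grad_tensors** when the output is non-scalar, i.e. providing the "vector" in the Jacobian-vector product:

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x * 2                               # non-scalar output
# y.backward() alone would raise: "grad can be implicitly created only for scalar outputs"
v = torch.tensor([1.0, 0.1, 0.01])      # the "vector" in the Jacobian-vector product
torch.autograd.backward((y,), (v,))     # equivalent to y.backward(v)
print(x.grad)                           # J^T v = 2 * v
```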

[3] Variable._execution_engine.run_backward

The line **from .variable import Variable** in that file tells us that **Variable** is defined in torch/autograd/variable.py. The code is shown below.

```python
# https://github.com/pytorch/pytorch/blob/master/torch/autograd/variable.py
import torch
from torch._six import with_metaclass


class VariableMeta(type):
    def __instancecheck__(cls, other):
        return isinstance(other, torch.Tensor)


class Variable(with_metaclass(VariableMeta, torch._C._LegacyVariableBase)):
    pass


from torch._C import _ImperativeEngine as ImperativeEngine
Variable._execution_engine = ImperativeEngine()

# https://github.com/pytorch/pytorch/tree/master/torch/csrc
```

The file is very short. The **Variable._execution_engine.run_backward** seen earlier is really a method of **_ImperativeEngine** from torch._C, i.e. compiled C++ code. On Windows the extension can be found under Python's Lib\site-packages\torch directory as a file like _C.cp35-win_amd64.pyd; the exact name varies with the Python version, but it is the same _C extension module. The implementation itself lives in pytorch/torch/csrc on GitHub.
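This is easy to verify from Python; a quick check (a sketch only, and the exact type name printed may differ across PyTorch versions):

```python
import torch
from torch.autograd.variable import Variable

engine = Variable._execution_engine
# The instance is created from torch._C._ImperativeEngine; note that the type's
# tp_name is "torch._C._EngineBase", which is why that name shows up here.
print(type(engine))                                    # e.g. <class 'torch._C._EngineBase'>
print(isinstance(engine, torch._C._ImperativeEngine))  # True
```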

[4] torch._C._ImperativeEngine

It is easy to find: the name _ImperativeEngine appears around line 308 of torch/csrc/autograd/python_engine.cpp, inside THPEngine_initModule. The code is as follows.

```cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/python_engine.cpp#L308
bool THPEngine_initModule(PyObject *module)
{
#ifndef _WIN32
  if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) {
    throw std::runtime_error("unable to set pthread_atfork handler");
  }
#endif
  if (PyType_Ready(&THPEngineType) < 0)
    return false;
  Py_INCREF(&THPEngineType);
  PyModule_AddObject(module, "_ImperativeEngine", (PyObject *)&THPEngineType);
  set_default_engine_stub(get_python_engine);
  return true;
}
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/Module.cpp#L679
```

**PyModule_AddObject** adds the object **(PyObject *)&THPEngineType** to the module under the name **_ImperativeEngine**. The **module** argument is a **PyObject**; the initialization function can be found around line 679 of torch/csrc/Module.cpp, and **module** itself is defined around line 67.
For details on **PyModule_AddObject**, see https://docs.python.org/3.5/c-api/module.html .
For background on writing Python extensions, see https://docs.python.org/3.5/extending/index.html .
Looking back, Variable._execution_engine.run_backward() is really _ImperativeEngine().run_backward(). From the definition of the THPEngineType object we can see that run_backward is only a thin wrapper; the actual C++ function is THPEngine_run_backward. This code is also in torch/csrc/autograd/python_engine.cpp.

```cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/python_engine.cpp#L308
// ...
static struct PyMethodDef THPEngine_methods[] = {
  {(char*)"run_backward", (PyCFunction)(void(*)(void))THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, nullptr},
  {(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, nullptr},
  {(char*)"is_checkpoint_valid", (PyCFunction)THPEngine_is_checkpoint_valid, METH_NOARGS, nullptr},
  {nullptr}
};
// ...
PyTypeObject THPEngineType = {
  // ...
  "torch._C._EngineBase",            /* tp_name */
  // ...
  THPEngine_methods,                 /* tp_methods */
  // ...
};
// ...
```
  17. // ...

The code uses **PyMethodDef**, a **struct** used to describe the methods of an extension type. Besides **run_backward**, the functions **queue_callback** and **is_checkpoint_valid** are registered here as well.
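Accordingly, the three entries of THPEngine_methods show up as bound methods on the engine object from the Python side; a quick, version-dependent check:

```python
import torch
from torch.autograd.variable import Variable

engine = Variable._execution_engine
for name in ("run_backward", "queue_callback", "is_checkpoint_valid"):
    print(name, hasattr(engine, name))   # expected: True for all three
```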

[5] THPEngine_run_backward

The comment on **THPEngine_run_backward** describes it as the **Implementation of torch._C._EngineBase.run_backward**, and the name **torch._C._EngineBase** is exactly what appears in the definition of **THPEngineType** above. The function is over a hundred lines long, so let's go through it piece by piece.
First, skip the middle part. The first and last statements of the function body, **HANDLE_TH_ERRORS** and **END_HANDLE_TH_ERRORS**, are macros defined in torch/csrc/Exceptions.h, around lines 41 and 114 respectively. This part of the code mainly calls **PyArg_ParseTupleAndKeywords** to parse the incoming arguments and assign them to the freshly declared variables **tensors**, **grad_tensors**, **keep_graph**, **create_graph**, **inputs**, and **allow_unreachable**.
For the usage of **PyArg_ParseTupleAndKeywords**, see https://docs.python.org/3.5/c-api/arg.html#c.PyArg_ParseTupleAndKeywords .

```cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/Exceptions.h
// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
  HANDLE_TH_ERRORS
  _maybe_reinitialize_engine_after_fork();
  PyObject *tensors = nullptr;
  PyObject *grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject *inputs = nullptr;
  unsigned char allow_unreachable = 0;
  const char *accepted_kwargs[] = {
      "tensors", "grad_tensors", "keep_graph", "create_graph", "inputs",
      "allow_unreachable", nullptr
  };
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Ob", (char**)accepted_kwargs,
        &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable))
    return nullptr;
  // ...
  END_HANDLE_TH_ERRORS
}
```

Now the middle part. This section checks the types of **tensors** and **grad_tensors** and verifies that the two tuples have the same size.

```cpp
// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
  // ...
  THPUtils_assert(PyTuple_Check(tensors), "tensors argument is expected to "
      "be a tuple, but got %s", THPUtils_typename(tensors));
  THPUtils_assert(PyTuple_Check(grad_tensors), "grad_tensors argument is "
      "expected to be a tuple, but got %s", THPUtils_typename(grad_tensors));
  Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors);
  Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors);
  THPUtils_assert(num_tensors == num_gradients, "got %ld tensors and %ld "
      "gradients", num_tensors, num_gradients);
  // ...
}
```

The next part is also fairly simple. It declares **edge_list roots;** and **variable_list grads;**, then loops over **tensors** and **grad_tensors**, pushing their elements into **roots** and **grads**. Concretely, each element is fetched with **PyTuple_GET_ITEM** and its value extracted via **((THPVariable*)...)->cdata**, with some checks along the way, e.g. that the element really is a **Tensor**.

```cpp
// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
  // ...
  edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
  grads.reserve(num_tensors);
  for (int i = 0; i < num_tensors; i++) {
    PyObject *_tensor = PyTuple_GET_ITEM(tensors, i);
    THPUtils_assert(THPVariable_Check(_tensor), "element %d of tensors "
        "tuple is not a Tensor", i);
    auto& variable = ((THPVariable*)_tensor)->cdata;
    auto gradient_edge = torch::autograd::impl::gradient_edge(variable);
    THPUtils_assert(gradient_edge.function,
        "element %d of tensors does not require grad and does not have a grad_fn", i);
    roots.push_back(std::move(gradient_edge));

    PyObject *grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      const Variable& grad_var = ((THPVariable*)grad)->cdata;
      if (grad_var.has_names()) {
        TORCH_WARN(
            "Autograd was passed a named grad tensor with dims ", grad_var.names(),
            ". Autograd does not yet support named tensor semantics, so all names ",
            "will be ignored. In practice all computed gradients will still be correct "
            "according to regular tensor semantics.");
      }
      grads.push_back(grad_var);
    } else {
      THPUtils_assert(grad == Py_None,
          "element %d of gradients tuple is not a Tensor or None", i);
      THPUtils_assert(!variable.requires_grad(),
          "element %d of gradients tuple is None, but the corresponding Tensor requires grad");
    }
  }
  // ...
}
```
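The gradient_edge assertion above is what surfaces as the familiar Python-level error when backward() is called on a tensor that is not connected to the graph; a small illustration with a made-up tensor:

```python
import torch

x = torch.ones(3)              # requires_grad defaults to False
y = (x * 2).sum()
try:
    y.backward()
except RuntimeError as e:
    # "element 0 of tensors does not require grad and does not have a grad_fn"
    print(e)
```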

Continuing on. For each element of **inputs**, the code reads its **output_nr** and **grad_fn**; if the tensor has no **grad_fn** (i.e. it is a leaf), its gradient accumulator is fetched with **torch::autograd::impl::try_get_grad_accumulator**. That function is declared around line 113 of torch/csrc/autograd/variable.h and implemented around line 111 of torch/csrc/autograd/variable.cpp; we will come back to it later. For now it is enough to know that it returns a pointer to a **Node** object. If the pointer is not null, **output_edges.emplace_back(grad_fn, output_nr)** is executed.
The difference between **push_back()** and **emplace_back()**: when **push_back()** is given a temporary object (an rvalue), the object is constructed first, then copied or moved into the container, and the temporary is destroyed afterwards; **emplace_back()** constructs the element in place inside the container, with no extra copy or move. See cpp/container/vector/emplace_back for details.

```cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/variable.h#L113
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/variable.cpp#L111
// https://en.cppreference.com/w/cpp/container/vector/emplace_back
// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
  // ...
  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      PyObject *input = PyTuple_GET_ITEM(inputs, i);
      THPUtils_assert(THPVariable_Check(input),
          "all inputs have to be Tensors, but got %s", THPUtils_typename(input));
      THPVariable *input_var = (THPVariable*)input;
      const auto output_nr = input_var->cdata.output_nr();
      auto grad_fn = input_var->cdata.grad_fn();
      if (!grad_fn) {
        grad_fn = torch::autograd::impl::try_get_grad_accumulator(input_var->cdata);
      }
      THPUtils_assert(input_var->cdata.requires_grad(),
          "One of the differentiated Tensors does not require grad");
      if (!grad_fn) {
        output_edges.emplace_back();
      } else {
        output_edges.emplace_back(grad_fn, output_nr);
      }
    }
  }
  // ...
}
```

Now let's look at the types of the two arguments passed into **output_edges**. **grad_fn** is a std::shared_ptr to a Node object; what about **output_nr**? The struct **THPVariable** is defined in torch/csrc/autograd/python_variable.h, shown below; its member **cdata** has type **torch::autograd::Variable**. The function **output_nr** is found in torch/csrc/autograd/VariableTypeManual.cpp, and it returns the member **uint32_t output_nr_** of the **AutogradMeta** struct defined in torch/csrc/autograd/variable.h, which matches exactly the parameter types of the **Edge** constructor defined in torch/csrc/autograd/edge.h.

```cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/python_variable.h#L12
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/VariableTypeManual.cpp#L134
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/variable.h#L179
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/edge.h#L14

// python_variable.h
// Python object that backs torch.autograd.Variable
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct THPVariable {
  PyObject_HEAD
  // Payload
  torch::autograd::Variable cdata;
  // Hooks to be run on backwards pass (corresponds to Python attr
  // '_backwards_hooks', set by 'register_hook')
  PyObject* backward_hooks = nullptr;
};

// =======================================================
// VariableTypeManual.cpp
int64_t output_nr(const Tensor & self) {
  if (impl::get_autograd_meta(self)) {
    return impl::get_autograd_meta(self)->output_nr_;
  } else {
    return 0;
  }
}

// =======================================================
// variable.h
struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
  // ...
  // The "output number" of this variable; e.g., if this variable
  // was the second output of a function, then output_nr == 1.
  // We use this to make sure we can setup the backwards trace
  // correctly when this variable is passed to another function.
  uint32_t output_nr_;
  // ...
};

// =======================================================
// edge.h
/// Represents a particular input of a function.
struct Edge {
  Edge() noexcept : function(nullptr), input_nr(0) {}

  Edge(std::shared_ptr<Node> function_, uint32_t input_nr_) noexcept
      : function(std::move(function_)), input_nr(input_nr_) {}
  // ...
};
```
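output_nr can also be observed from Python. A sketch, assuming the undocumented Tensor attribute output_nr (present in the versions discussed here) and that torch.unbind has a single multi-output backward node:

```python
import torch

x = torch.randn(3, requires_grad=True)
a, b, c = x.unbind()              # one autograd Node (UnbindBackward) with three outputs
print(a.grad_fn, b.grad_fn)       # the outputs share the same grad_fn node
print(a.output_nr, b.output_nr, c.output_nr)   # 0 1 2 -- each output's position on that node
```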

The last part of the function is shown below. Note **THPUtils_assert(allow_unreachable ... );**: the **allow_unreachable** flag can be traced back to the source in part [2], where **Variable._execution_engine.run_backward** passes **allow_unreachable=True**. **PyTuple_GET_SIZE** returns the size of the given tuple; **PyTuple_New** creates a new **tuple** object of the given size; **PyTuple_SET_ITEM** stores **THPVariable_Wrap(outputs[i])** at position i of **py_outputs.get()**. The key call here is **engine.execute**, discussed in detail below.

```cpp
// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
  // ...
  variable_list outputs;
  {
    pybind11::gil_scoped_release no_gil;
    outputs = engine.execute(roots, grads, keep_graph, create_graph, output_edges);
  }

  if (inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    THPObjectPtr py_outputs {PyTuple_New(num_inputs)};
    if (!py_outputs) return nullptr;
    for (int i = 0; i < num_inputs; i++) {
      THPUtils_assert(allow_unreachable || outputs[i].defined(), "One of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
      PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
    }
    return py_outputs.release();
  } else {
    Py_RETURN_NONE;
  }
  // ...
}
```
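From **backward()** the flag is hard-wired to allow_unreachable=True, so this assert never fires there; it matters for **torch.autograd.grad**, which (in the versions discussed here) forwards its allow_unused argument as allow_unreachable. A small illustration with made-up tensors:

```python
import torch

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
z = (x * 2).sum()                 # y is never used in the graph

# torch.autograd.grad(z, (x, y)) would raise:
#   "One of the differentiated Tensors appears to not have been used in the graph.
#    Set allow_unused=True if this is the desired behavior."
gx, gy = torch.autograd.grad(z, (x, y), allow_unused=True)
print(gx)      # tensor([2., 2., 2.])
print(gy)      # None -- y is unreachable, allowed because allow_unused=True
```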

Functions such as **PyTuple_GET_SIZE**, **PyTuple_New**, and **PyTuple_SET_ITEM** let C++ code manipulate Python objects. The tuple-related functions are documented at https://docs.python.org/3.5/c-api/tuple.html . The naming is easy to remember: names containing PyTuple operate on Python tuple objects, PyList on Python list objects, and PyType on Python type objects. See https://docs.python.org/3.5/c-api/concrete.html for more.

[6] try_get_grad_accumulator

Now back to **try_get_grad_accumulator**, declared around line 113 of torch/csrc/autograd/variable.h and implemented around line 111 of torch/csrc/autograd/variable.cpp. A simplified version of the source is shown below.

```cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/variable.h#L113
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/variable.cpp#L111

// variable.h
// ...
namespace torch { namespace autograd {

struct Node;
struct AutogradMeta;
struct DifferentiableViewMeta;

using Variable = at::Tensor;

namespace impl {
  // ...
  TORCH_API AutogradMeta* get_autograd_meta(const Variable&);
  // ...
  TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);
  // ...
}

// ...
struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
  // ...
  std::weak_ptr<Node> grad_accumulator_;
  // ...
};

// =================================
// variable.cpp
// ...
namespace torch {
namespace autograd {
// ...
namespace impl {
  // ...
  std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) {
    if (get_autograd_meta(self)) {
      return get_autograd_meta(self)->grad_accumulator_.lock();
    } else {
      return nullptr;
    }
  }
  // ...
  AutogradMeta* get_autograd_meta(const Variable& self) {
    // NB: could return null
    TORCH_CHECK(self.defined(), "cannot call get_autograd_meta() on undefined tensor");
    return static_cast<AutogradMeta*>(self.unsafeGetTensorImpl()->autograd_meta());
  }
  // ...
}
// ...
```
So **try_get_grad_accumulator** first obtains the **AutogradMeta** struct via **get_autograd_meta**, then reads its member **grad_accumulator_**, a **std::weak_ptr** to a **Node** object. The **lock()** call creates a **std::shared_ptr** that manages the object, so **try_get_grad_accumulator** returns a std::shared_ptr<Node>.
**weak_ptr** is a smart pointer designed to be used together with **shared_ptr**; see https://en.cppreference.com/w/cpp/memory/weak_ptr and https://en.cppreference.com/w/cpp/memory/weak_ptr/lock .
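The gradient accumulator of a leaf tensor is itself a Node, and it is reachable from Python by walking grad_fn.next_functions; a short sketch with made-up tensors:

```python
import torch

x = torch.randn(2, requires_grad=True)   # leaf tensor
y = (x * 3).sum()

print(y.grad_fn)                          # SumBackward0
mul_node, _ = y.grad_fn.next_functions[0]
print(mul_node)                           # MulBackward0
# For the leaf x, the edge points at its AccumulateGrad node -- the Node that
# try_get_grad_accumulator returns on the C++ side.
print(mul_node.next_functions)            # ((AccumulateGrad ...), (None, 0))
```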

[7] engine.execute(roots, grads, keep_graph, create_graph, output_edges)

Continuing from part [5]: the crucial **variable_list outputs;** is produced by **engine.execute**. **engine** is defined at line 26 of torch/csrc/autograd/python_engine.cpp, as follows.

```cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/python_engine.cpp#L26
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/python_engine.h#L12
static torch::autograd::python::PythonEngine engine;
```

**torch::autograd::python::PythonEngine** is defined in torch/csrc/autograd/python_engine.h, shown below. The struct **PythonEngine** inherits from **Engine**, and its **execute** method overrides **Engine::execute**, so the function we need to study becomes **Engine::execute(roots, inputs, keep_graph, create_graph, outputs)**.

```cpp
// python_engine.h
namespace torch { namespace autograd { namespace python {

struct PythonEngine : public Engine {
  void thread_init(int device) override;
  void thread_on_exception(
      std::shared_ptr<GraphTask>& graph_task,
      const std::shared_ptr<Node>& fn,
      std::exception& e) override;
  variable_list execute(
      const edge_list& roots,
      const variable_list& inputs,
      bool keep_graph,
      bool create_graph,
      const edge_list& outputs = {}) override;
  variable_list execute_with_graph_task(
      std::shared_ptr<GraphTask> graph_task,
      std::shared_ptr<Node> graph_root) override;
  std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() override;
};

}}} // namespace torch::autograd::python

// ================================================
// python_engine.cpp
variable_list PythonEngine::execute(
    const edge_list& roots,
    const variable_list& inputs,
    bool keep_graph,
    bool create_graph,
    const edge_list& outputs) {
  try {
    return Engine::execute(roots, inputs, keep_graph, create_graph, outputs);
  } catch (python_error& e) {
    e.restore();
    throw;
  }
}
```

Next, let's look at torch/csrc/autograd/engine.h and torch/csrc/autograd/engine.cpp.

```cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/engine.h
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/engine.cpp
variable_list Engine::execute(const edge_list& roots,
                              const variable_list& inputs,
                              bool keep_graph,
                              bool create_graph,
                              const edge_list& outputs);
// calls ↓
variable_list Engine::execute_with_graph_task(std::shared_ptr<GraphTask> graph_task,
                                              std::shared_ptr<Node> graph_root);
// calls ↓
void Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task,
                         bool reentrant_thread);
// calls ↓
void Engine::evaluate_function(std::shared_ptr<GraphTask>& graph_task,
                               Node* func,
                               InputBuffer& inputs);
// calls ↓ (this function is not a method of the Engine struct)
variable_list call_function(std::shared_ptr<GraphTask>& graph_task,
                            Node* func,
                            InputBuffer& inputBuffer) {
  // ...
  auto& fn = *func;
  auto inputs = call_pre_hooks(fn, InputBuffer::variables(std::move(inputBuffer)));
  variable_list outputs = fn(std::move(inputs));
  // ...
  if (has_post_hooks) {
    // NOLINTNEXTLINE(bugprone-use-after-move)
    return call_post_hooks(fn, std::move(outputs), inputs);
  }
  return outputs;
}

// =================================
static variable_list call_pre_hooks(Node& fn, variable_list inputs) {
  for (const auto& hook : fn.pre_hooks()) {
    inputs = (*hook)(inputs);
  }
  return inputs;
}

static variable_list call_post_hooks(Node& fn, variable_list outputs, const variable_list& inputs) {
  for (const auto& hook : fn.post_hooks()) {
    outputs = (*hook)(outputs, inputs);
  }
  return outputs;
}

// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/function.h#L87
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/engine.h#L38
```

Let's sort out the structs encountered in these functions. First, **Node**, defined around line 87 of torch/csrc/autograd/function.h: a Node represents an operation and can be thought of as a vertex in the autograd graph. The struct **GraphTask** is defined around line 38 of torch/csrc/autograd/engine.h; as its comment says, GraphTask holds metadata needed for a single execution of backward().
**fn.post_hooks()** and **fn.pre_hooks()** return the member variables **post_hooks_** and **pre_hooks_**, whose types are **std::vector<std::unique_ptr<FunctionPostHook>>** and **std::vector<std::unique_ptr<FunctionPreHook>>** respectively; this brings in two more structs, FunctionPreHook and FunctionPostHook. As for **unique_ptr**: unlike **shared_ptr**, at any moment only one **unique_ptr** may point to a given object, and when the **unique_ptr** is destroyed, the object it owns is destroyed as well; **unique_ptr** expresses exclusive ownership. **FunctionPreHook** and **FunctionPostHook** are both defined in torch/csrc/autograd/function_hook.h.

```cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/function_hook.h
#pragma once

#include <vector>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/Tensor.h>

// A hook that's called on gradients
namespace torch { namespace autograd {

using Variable = at::Tensor;
using variable_list = std::vector<Variable>;

struct TORCH_API FunctionPreHook {
  virtual ~FunctionPreHook();
  virtual variable_list operator()(const variable_list& grads) = 0;
};

struct TORCH_API FunctionPostHook {
  virtual ~FunctionPostHook();
  virtual variable_list operator()(
      const variable_list& outputs /* grad_inputs */,
      const variable_list& inputs /* grad_outputs */) = 0;
};

}} // namespace torch::autograd
```
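From Python, hooks registered with Tensor.register_hook end up being invoked through this pre-hook machinery during the backward pass: the hook receives the gradient and may return a modified one. A minimal sketch with a made-up tensor:

```python
import torch

x = torch.randn(3, requires_grad=True)
y = (x ** 2).sum()

# The hook is called with the gradient w.r.t. x during backward and may return
# a replacement; here it scales the gradient by 2.
handle = x.register_hook(lambda grad: grad * 2.0)

y.backward()
print(torch.allclose(x.grad, 4 * x))   # True: 2 * (2 * x)
handle.remove()
```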