This article references:
- https://www.cnblogs.com/catnip/p/8760780.html
- https://blog.csdn.net/u012436149/article/details/78510945
Both articles are excellent, but they analyze the source code of PyTorch versions before 0.4. This article performs the analysis on the PyTorch 1.7 source, and adds UML diagrams for a more intuitive picture.
[torch/autograd/function.py]
class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)):  # type: ignore
    """Records operation history and defines formulas for differentiating ops."""

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError("You must implement the forward function for custom"
                                  " autograd.Function.")

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        raise NotImplementedError("You must implement the backward function for custom"
                                  " autograd.Function.")
with_metaclass() is introduced to bridge the metaclass syntax incompatibility between Python 2 and Python 3 (for details see "Understanding the with_metaclass()" and "How does it work - with_metaclass"); it is defined in torch/_six.py.
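For reference, the definition in torch/_six.py follows the well-known recipe from the six library; the body below is reproduced from memory, so treat it as a close sketch rather than a verbatim quote.
[torch/_six.py]
def with_metaclass(meta, *bases):
    # Create a dummy metaclass for one level of class instantiation that
    # replaces itself with the actual metaclass.
    class metaclass(meta):
        def __new__(cls, name, this_bases, d):
            return meta(name, bases, d)
    return type.__new__(metaclass, 'temporary_class', (), {})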
This is therefore equivalent to the following Python 3 definition (note that the metaclass keyword argument must come after the positional bases):

class Function(_C._FunctionBase, _ContextMethodMixin, _HookMixin, metaclass=FunctionMeta): ...
[torch/autograd/function.py]
class BackwardCFunction(_C._FunctionBase, _ContextMethodMixin, _HookMixin):
    _is_legacy = False

    def apply(self, *args):
        # _forward_cls is defined by derived class
        return self._forward_cls.backward(self, *args)  # type: ignore


class FunctionMeta(type):
    def __init__(cls, name, bases, attrs):
        ...
        backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls})
        cls._backward_cls = backward_fn
        return super(FunctionMeta, cls).__init__(name, bases, attrs)
FunctionMeta inherits from type and serves as Function's metaclass; a metaclass's job is to create classes. FunctionMeta attaches an attribute _backward_cls to every class it creates, whose value is backward_fn. backward_fn is itself a class generated dynamically with type(): it inherits from BackwardCFunction and carries an attribute _forward_cls whose value is cls, i.e. the very class FunctionMeta is creating.
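To isolate this pattern, here is a minimal standalone sketch with made-up names (Meta, Foo) and no PyTorch dependency; the Exp example that follows then exercises the real machinery:

class Meta(type):
    def __init__(cls, name, bases, attrs):
        # Dynamically create a companion class named '<name>Backward' that
        # points back at the class being constructed.
        backward_fn = type(name + 'Backward', (object,), {'_forward_cls': cls})
        cls._backward_cls = backward_fn
        super(Meta, cls).__init__(name, bases, attrs)

class Foo(metaclass=Meta):
    pass

print(Foo._backward_cls)                 # <class '__main__.FooBackward'>
print(Foo._backward_cls._forward_cls)    # <class '__main__.Foo'>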
class Exp(Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result
>>> Exp._backward_cls
torch.autograd.function.ExpBackward
>>> Exp._backward_cls()
<torch.autograd.function.ExpBackward at 0x7fe9a4519e60>
>>> Exp._backward_cls._forward_cls
__main__.Exp
>>> Exp._backward_cls._forward_cls.apply?
Docstring: <no docstring>
Type: builtin_function_or_method
>>> Exp._backward_cls.apply?
Signature: Exp._backward_cls.apply(self, *args)
Docstring: <no docstring>
File: ~/miniconda3/envs/jupyter/lib/python3.7/site-packages/torch/autograd/function.py
Type: function
When FunctionMeta creates the Exp class, it creates an ExpBackward class at the same time. As the output shows, Exp and ExpBackward both inherit from _C._FunctionBase, _ContextMethodMixin and _HookMixin. Moreover, ExpBackward inherits the apply method of BackwardCFunction, and that apply calls the backward method defined on Exp.
At this point we know how ExpBackward reaches Exp's backward. So how does Exp's forward get called? The answer is the apply method of _C._FunctionBase.
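Before moving to the C++ side, a quick runnable sanity check (my own snippet, assuming the Exp class defined above):

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = Exp.apply(x)            # _C._FunctionBase.apply -> THPFunction_apply -> Exp.forward
print(y.grad_fn)            # <torch.autograd.function.ExpBackward object at 0x...>
y.sum().backward()          # the engine calls ExpBackward.apply -> Exp.backward
print(torch.allclose(x.grad, y.detach()))   # True, since d(exp(x))/dx = exp(x)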
[torch/csrc/autograd/python_function.cpp]
PyTypeObject THPFunctionType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._FunctionBase", /* tp_name */
...
THPFunction_methods, /* tp_methods */
nullptr, /* tp_members */
THPFunction_properties, /* tp_getset */
...
THPFunction_new /* tp_new */
};
Here we only care about THPFunction_methods.
[torch/csrc/autograd/python_function.cpp]
static struct PyMethodDef THPFunction_methods[] = {
...
{(char*)"apply", THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr},
...
};
THPFunction_methods contains an apply entry, which points to THPFunction_apply.
[torch/csrc/autograd/python_function.cpp]
PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
{
HANDLE_TH_ERRORS
RECORD_FUNCTION(
((PyTypeObject*)cls)->tp_name,
std::vector<c10::IValue>(),
at::sequence_number::peek());
THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
if (!backward_cls) return nullptr;
THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr));
if (!ctx_obj) return nullptr;
THPFunction* ctx = (THPFunction*)ctx_obj.get();
auto cdata = std::shared_ptr<PyNode>(new PyNode(std::move(ctx_obj)), deleteNode);
ctx->cdata = cdata;
// Prepare inputs and allocate context (grad fn)
auto info_pair = unpack_input<false>(inputs);
UnpackedInput& unpacked_input = info_pair.first;
InputFlags& input_info = info_pair.second;
// Record input nodes if tracing
auto* node = _trace_pre_record(cls, inputs, unpacked_input.input_vars);
// Initialize backward function (and ctx)
bool is_executable = input_info.is_executable;
cdata->set_next_edges(std::move(input_info.next_edges));
ctx->needs_input_grad = input_info.needs_input_grad.release();
ctx->is_variable_input = std::move(input_info.is_variable_input);
// Prepend ctx to input_tuple, in preparation for static method call
auto num_args = PyTuple_GET_SIZE(inputs);
THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1));
if (!ctx_input_tuple) return nullptr;
Py_INCREF(ctx);
PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, (PyObject*)ctx);
for (int i = 0; i < num_args; ++i) {
PyObject *arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i);
Py_INCREF(arg);
PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg);
}
// Call forward
THPObjectPtr tensor_outputs;
{
AutoGradMode grad_mode(false);
THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
if (!forward_fn) return nullptr;
tensor_outputs = PyObject_CallObject(forward_fn, ctx_input_tuple);
if (!tensor_outputs) return nullptr;
}
return process_outputs(cls, cdata, ctx, unpacked_input, inputs, std::move(tensor_outputs),
is_executable, node);
END_HANDLE_TH_ERRORS
}
THPFunction_apply is logically divided into five parts (a runnable check follows this list):

1. Prepare inputs and allocate context (grad fn): parse the inputs argument with unpack_input to obtain the two objects unpacked_input and input_info.
2. Record input nodes if tracing.
3. Initialize backward function (and ctx): initialize the backward function and ctx (ctx stores the information needed by backward).
4. Prepend ctx to input_tuple, in preparation for static method call: rebuild the argument tuple by placing ctx at the head of ctx_input_tuple and then copying the entries of unpacked_input.input_tuple in order.
5. Call forward: invoke the user-defined forward.
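A detail worth noting in part 5: forward runs under AutoGradMode grad_mode(false), so gradient recording is off inside the user's forward, and the graph is attached afterwards by process_outputs. This is observable from Python (Probe is a made-up example class):

import torch

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # grad mode was disabled by THPFunction_apply before calling us
        print("grad enabled inside forward:", torch.is_grad_enabled())  # False
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

x = torch.ones(2, requires_grad=True)
y = Probe.apply(x)          # prints: grad enabled inside forward: False
print(y.requires_grad)      # True: process_outputs re-attached the graph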
We know that PyTorch builds its computation graph during the forward pass, so which part of THPFunction_apply is the code that builds the graph?
[torch/csrc/autograd/python_function.cpp]
struct UnpackedInput {
THPObjectPtr input_tuple;
variable_list input_vars;
};
struct InputFlags {
bool is_executable = false;
edge_list next_edges;
THPObjectPtr needs_input_grad;
std::vector<bool> is_variable_input;
};
The input_tuple field of UnpackedInput stores the parsed contents of inputs, while InputFlags is filled in by inspecting the components of inputs one by one to determine each variable's gradient flags.
[torch/csrc/autograd/python_function.cpp]
template<bool enforce_variables>
std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) {
UnpackedInput unpacked;
InputFlags flags;
auto num_args = PyTuple_GET_SIZE(args);
unpacked.input_tuple = PyTuple_New(num_args);
flags.needs_input_grad = PyTuple_New(num_args);
for (int i = 0; i < num_args; i++) {
PyObject *arg = PyTuple_GET_ITEM(args, i);
bool is_variable = THPVariable_Check(arg);
flags.is_variable_input.push_back(is_variable);
if (!is_variable) {
// TODO: remove this code path once Variable and Tensor are merged in Python
if (enforce_variables) {
THPUtils_setError("expected a Variable argument, but got %s",
THPUtils_typename(arg));
throw python_error();
}
Py_INCREF(Py_False);
PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, Py_False);
} else {
THPVariable* variable = (THPVariable*)arg;
unpacked.input_vars.push_back(variable->cdata);
PyObject* needs_grad = variable->cdata.requires_grad() ? Py_True : Py_False;
Py_INCREF(needs_grad);
PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad);
}
Py_INCREF(arg);
PyTuple_SET_ITEM(unpacked.input_tuple.get(), i, arg);
}
flags.is_executable = GradMode::is_enabled() && any_variable_requires_grad(unpacked.input_vars);
flags.next_edges = collect_next_edges(unpacked.input_vars);
return std::make_pair(std::move(unpacked), std::move(flags));
}
InputFlags.needs_input_grad records whether each arg requires gradients, and InputFlags.is_variable_input records whether each arg is a Variable.
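The needs_input_grad tuple built here is exposed to Python as ctx.needs_input_grad (THPFunction_apply assigns it to ctx above), which a custom Function can use to skip unnecessary gradient computations. A small made-up example (Mul):

import torch

class Mul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        print("needs_input_grad:", ctx.needs_input_grad)
        return a * b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        # Return None for inputs that do not require grad
        grad_a = grad_output * b if ctx.needs_input_grad[0] else None
        grad_b = grad_output * a if ctx.needs_input_grad[1] else None
        return grad_a, grad_b

a = torch.randn(3, requires_grad=True)
b = torch.randn(3)          # requires_grad defaults to False
Mul.apply(a, b)             # prints: needs_input_grad: (True, False)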
flags.next_edges
flags.next_edges = collect_next_edges(unpacked.input_vars);
This line is the key to building the computation graph (i.e., the key to chain-rule differentiation during backpropagation).
[torch/csrc/autograd/function.h]
/// Return the next edges of all the given variables, or tuples of variables.
template <typename... Variables>
edge_list collect_next_edges(Variables&&... variables) {
if (!GradMode::is_enabled())
return {};
detail::MakeNextFunctionList make;
make.apply(std::forward<Variables>(variables)...);
return std::move(make.next_edges);
}
MakeNextFunctionList
[torch/csrc/autograd/function.h]
struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
edge_list next_edges;
using IterArgs<MakeNextFunctionList>::operator();
void operator()(const Variable& variable) {
if (variable.defined()) {
next_edges.push_back(impl::gradient_edge(variable));
} else {
next_edges.emplace_back();
}
}
void operator()(const c10::optional<Variable>& variable) {
if (variable.has_value() && variable->defined()) {
next_edges.push_back(impl::gradient_edge(*variable));
} else {
next_edges.emplace_back();
}
}
};
gradient_edge
[torch/csrc/autograd/variable.cpp]
Edge gradient_edge(const Variable& self) {
// If grad_fn is null (as is the case for a leaf node), we instead
// interpret the gradient function to be a gradient accumulator, which will
// accumulate its inputs into the grad property of the variable. These
// nodes get suppressed in some situations, see "suppress gradient
// accumulation" below. Note that only variables which have `requires_grad =
// True` can have gradient accumulators.
if (const auto& gradient = self.grad_fn()) {
return Edge(gradient, self.output_nr());
} else {
return Edge(grad_accumulator(self), 0);
}
}
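Both branches of gradient_edge can be observed from Python: a non-leaf variable carries a grad_fn, while a leaf falls back to a gradient accumulator (which later shows up as AccumulateGrad). A quick illustration:

import torch

x = torch.ones(2, requires_grad=True)
print(x.grad_fn)   # None -> gradient_edge falls back to Edge(grad_accumulator(x), 0)
y = x * 2
print(y.grad_fn)   # <MulBackward0 ...> -> Edge(y.grad_fn, output_nr)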
At this point, the edges to connect have been found. The final step is process_outputs, which wraps the outputs of forward and hooks the new node onto them:
[torch/csrc/autograd/python_function.cpp]
PyObject* process_outputs(PyObject *op_obj, const std::shared_ptr<PyNode>& cdata,
THPFunction* grad_fn, const UnpackedInput& unpacked,
PyObject *inputs, THPObjectPtr&& raw_output, bool is_executable,
torch::jit::Node* node) {
bool unpack_output = ensure_tuple(raw_output);
auto num_outputs = PyTuple_GET_SIZE(raw_output.get());
THPObjectPtr outputs(PyTuple_New(num_outputs));
if (!outputs) throw python_error();
cdata->clear_input_metadata();
// Record type, device, and size information about inputs
if (is_executable) {
grad_fn->input_info.clear();
grad_fn->input_info.reserve(unpacked.input_vars.size());
for (auto& var : unpacked.input_vars) {
grad_fn->input_info.emplace_back(var);
}
}
bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
_wrap_outputs(cdata, grad_fn, unpacked.input_vars, raw_output, outputs, is_executable);
_trace_post_record(node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output);
if (is_executable) {
_save_variables(cdata, grad_fn);
} else {
// Remove unnecessary attributes
Py_XDECREF(grad_fn->to_save);
grad_fn->to_save = nullptr;
Py_XDECREF(grad_fn->non_differentiable);
grad_fn->non_differentiable = nullptr;
}
// Unpack the output, unless .forward() returned a tuple
if (unpack_output) {
PyObject *output = PyTuple_GET_ITEM(outputs.get(), 0);
Py_INCREF(output);
return output;
}
return outputs.release();
}
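To close the loop: inside process_outputs, _wrap_outputs installs cdata (the PyNode wrapping the ExpBackward object) as the grad_fn of each output; together with the next_edges collected earlier, this links the new node into the graph. An end-to-end check with the Exp class from above:

import torch

x = torch.tensor([0.0, 1.0], requires_grad=True)
y = Exp.apply(x)
print(y.grad_fn)                  # <torch.autograd.function.ExpBackward object at 0x...>
print(y.grad_fn.next_functions)   # ((<AccumulateGrad object at 0x...>, 0),)
                                  # i.e. the edge collected by collect_next_edges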