【PyTorch】聊聊 backward 背后的代码

【PyTorch】聊聊 backward 背后的代码

说起backward大家肯定不陌生,用过PyTorch的肯定都知道,这个函数的作用是反向传播计算梯度的。比如下边这个例子,要反向传播计算梯度之后,才能调用优化器的step函数更新网络模型参数。

Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step() 

[1] torch.Tensor.backward

torch/tensor.py 文件中可以看到,class Tensor(torch._C._TensorBase)中有函数def backward。所以我们可以用tensor.backward()来进行反向传播。

def backward(self, gradient=None, retain_graph=None, create_graph=False):
    r"""Computes the gradient of current tensor w.r.t. graph leaves.

    The graph is differentiated using the chain rule. If the tensor is
    non-scalar (i.e. its data has more than one element) and requires
    gradient, the function additionally requires specifying ``gradient``.
    It should be a tensor of matching type and location, that contains
    the gradient of the differentiated function w.r.t. ``self``.

    This function accumulates gradients in the leaves - you might need to
    zero them before calling it.

    Arguments:
        gradient (Tensor or None): Gradient w.r.t. the
            tensor. If it is a tensor, it will be automatically converted
            to a Tensor that does not require grad unless ``create_graph`` is True.
            None values can be specified for scalar Tensors or ones that
            don't require grad. If a None value would be acceptable then
            this argument is optional.
        retain_graph (bool, optional): If ``False``, the graph used to compute
            the grads will be freed. Note that in nearly all cases setting
            this option to True is not needed and often can be worked around
            in a much more efficient way. Defaults to the value of
            ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will
            be constructed, allowing to compute higher order derivative
            products. Defaults to ``False``.
    """
    torch.autograd.backward(self, gradient, retain_graph, create_graph) 

其中,create_graph参数的作用是,如果为True,那么就创建一个专门的graph of the derivative,这可以方便计算高阶微分。参数retain_graph可以忽略,因为绝大多数情况根本不需要,它的作用是要不要保留Graph。该函数实现代码也很简单,就是调用torch.autograd.backward。所以接下来看一下torch.autograd.backward中的实现。

[2] torch.autograd.backward

函数torch.autograd.backward的定义在文件 torch/autograd/__init__.py 中。借助于链式法则the chain ruleJacobian-vector product可以很方便的计算梯度。下边就是具体的代码。

# ...

from .variable import Variable

# ... 

def _make_grads(outputs, grads):
    new_grads = []
    for out, grad in zip(outputs, grads):
        if isinstance(grad, torch.Tensor):
            if not out.shape == grad.shape:
                # raise RuntimeError ...
            new_grads.append(grad)
        elif grad is None:
            if out.requires_grad:
                if out.numel() != 1:
                    # raise RuntimeError ...
            else:
                new_grads.append(None)
        else:
            # raise TypeError ...
    return tuple(new_grads)


def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None):
    r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.

    The graph is differentiated using the chain rule. If any of ``tensors``
    are non-scalar (i.e. their data has more than one element) and require
    gradient, then the Jacobian-vector product would be computed, in this
    case the function additionally requires specifying ``grad_tensors``.
    It should be a sequence of matching length, that contains the "vector"
    in the Jacobian-vector product, usually the gradient of the differentiated
    function w.r.t. corresponding tensors (``None`` is an acceptable value for
    all tensors that don't need gradient tensors).

    This function accumulates gradients in the leaves - you might need to zero
    them before calling it.
    """
    if grad_variables is not None:
        warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
        if grad_tensors is None:
            grad_tensors = grad_variables
        else:
            raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
                               "arguments both passed to backward(). Please only "
                               "use 'grad_tensors'.")

    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)

    if grad_tensors is None:
        grad_tensors = [None] * len(tensors)
    elif isinstance(grad_tensors, torch.Tensor):
        grad_tensors = [grad_tensors]
    else:
        grad_tensors = list(grad_tensors)

    grad_tensors = _make_grads(tensors, grad_tensors)
    if retain_graph is None:
        retain_graph = create_graph

    Variable._execution_engine.run_backward(
        tensors, grad_tensors, retain_graph, create_graph,
        allow_unreachable=True)  # allow_unreachable flag

# ...

if not torch._C._autograd_init():
    raise RuntimeError("autograd initialization failed")

参数grad_variables是老版本的,已经被deprecated,现在使用的是grad_tensors。即便你使用了也没关系,代码会把参数grad_variables的值传给参数grad_tensors以供使用。代码中用到了函数_make_grads,该函数主要是对grad_tensors中的元素进行检查并且将grad_tensors重新组织成tuple(list(torch.Tensor, ...))的形式。做完这一系列操作之后就是调用Variable._execution_engine.run_backward,并且将这些被check和重新组织的参数传给该函数。注意参数allow_unreachable,下边还会遇到。

[3] Variable._execution_engine.run_backward

从文件中的代码from .variable import Variable可以知道,Variable的定义在文件 torch/autograd/variable.py 中。具体代码如下。

import torch
from torch._six import with_metaclass


class VariableMeta(type):
    def __instancecheck__(cls, other):
        return isinstance(other, torch.Tensor)


class Variable(with_metaclass(VariableMeta, torch._C._LegacyVariableBase)):
    pass


from torch._C import _ImperativeEngine as ImperativeEngine
Variable._execution_engine = ImperativeEngine()

代码内容很短,可以看到,前边看到的函数Variable._execution_engine.run_backward其实就是torch._C中的函数_ImperativeEnginetorch._C这个是调用的被编译之后的C++代码,Windows系统下可以在Python目录\Lib\site-packages\torch下找到_C.cp35-win_amd64.pyd这个文件,当然不同的Python版本名称也会略有不同,但是这个_C.pyd是一样的。具体的函数实现代码我们可以从GitHubpytorch/torch/csrc 这里找到。

[4] torch._C._ImperativeEngine

很容易就可以找到,函数_ImperativeEngine在文件 torch/csrc/autograd/python_engine.cpp 中的第 308 行出现。代码如下。

bool THPEngine_initModule(PyObject *module)
{
#ifndef _WIN32
  if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) {
    throw std::runtime_error("unable to set pthread_atfork handler");
  }
#endif
  if (PyType_Ready(&THPEngineType) < 0)
    return false;
  Py_INCREF(&THPEngineType);
  PyModule_AddObject(module, "_ImperativeEngine", (PyObject *)&THPEngineType);
  set_default_engine_stub(get_python_engine);
  return true;
}

通过函数PyModule_AddObject(PyObject *)&THPEngineType这个object加入到模块module中并命名为_ImperativeEngine。这个module的类型是PyObject,这个初始化函数可以在文件 torch/csrc/Module.cpp 的第 679 行找到,module的定义则是在第 67 行。

关于函数PyModule_AddObject的详细介绍可以参考 docs.python.org/3.5/c-a 。另外关于 Python 扩展的相关知识,可以参考 Extending and Embedding the Python Interpreter 进行学习,有机会可以把这块单拎出来再写一篇文章。

现在回过头来看之前的Variable._execution_engine.run_backward()其实就是_ImperativeEngine().run_backward()。从对象THPEngineType的定义可以找到run_backward也只是个外套,具体的C++函数其实是THPEngine_run_backward。这部分代码仍然是在 torch/csrc/autograd/python_engine.cpp 中。

// ...
 
static struct PyMethodDef THPEngine_methods[] = {
  {(char*)"run_backward", (PyCFunction)(void(*)(void))THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, nullptr},
  {(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, nullptr},
  {(char*)"is_checkpoint_valid", (PyCFunction)THPEngine_is_checkpoint_valid, METH_NOARGS, nullptr},
  {nullptr}
};

// ...

PyTypeObject THPEngineType = {
    // ...
    "torch._C._EngineBase",                      /* tp_name */
    // ...
    THPEngine_methods,                           /* tp_methods */
    // ...
};

// ...

代码中使用了PyMethodDef,该函数是用于描述扩展方法的struct。可以看到,除了我们要找的函数run_backward,此处还定义了函数queue_callback和函数is_checkpoint_valid

[5] THPEngine_run_backward

关于函数THPEngine_run_backward的介绍是Implementation of torch._C._EngineBase.run_backward,而torch._C._EngineBase这个名字在THPEngineType的定义部分的代码可以找到。该部分代码超过一百行了,下边分块来看一下。

首先把中间部分略去。函数内第一行和最后一行的HANDLE_TH_ERRORSEND_HANDLE_TH_ERRORS,是在文件 torch/csrc/Exceptions.h 中定义的宏,具体地,分别在第 41 行和第 114 行被定义。这部分代码主要是通过函数PyArg_ParseTupleAndKeywords对输入的参数重新解析并赋值给新定义的变量tensorsgrad_tensorskeep_graphcreate_graphinputs以及allow_unreachable

有关函数PyArg_ParseTupleAndKeywords的用法详见 docs.python.org/3.5/c-a

// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
    HANDLE_TH_ERRORS
    _maybe_reinitialize_engine_after_fork();
    PyObject *tensors = nullptr;
    PyObject *grad_tensors = nullptr;
    unsigned char keep_graph = 0;
    unsigned char create_graph = 0;
    PyObject *inputs = nullptr;
    unsigned char allow_unreachable = 0;
    const char *accepted_kwargs[] = {
            "tensors", "grad_tensors", "keep_graph", "create_graph", "inputs",
            "allow_unreachable", nullptr
    };
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Ob", (char**)accepted_kwargs,
                &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable))
        return nullptr;

    // ...

    END_HANDLE_TH_ERRORS
}

下面来看下中间的部分。这部分主要是Check一下tensorsgrad_tensors的变量类型,并且检查二者的tuple size是否一致。

// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
    // ...

    THPUtils_assert(PyTuple_Check(tensors), "tensors argument is expected to "
            "be a tuple, but got %s", THPUtils_typename(tensors));
    THPUtils_assert(PyTuple_Check(grad_tensors), "grad_tensors argument is "
            "expected to be a tuple, but got %s", THPUtils_typename(grad_tensors));

    Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors);
    Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors);
    THPUtils_assert(num_tensors == num_gradients, "got %ld tensors and %ld "
            "gradients", num_tensors, num_gradients);

    // ...
}

下边这部分代码也比较简单。先是定义edge_list roots;variable_list grads;。接下来通过循环把tensorsgrad_tensors中的元素push_backrootsgrads。具体地,先通过PyTuple_GET_ITEM取出元素,再利用((THPVariable*)···)->cdata取出元素的值。当然中间也会做一些Check,例如是否为Tensor之类的。

// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
    // ...

    edge_list roots;
    roots.reserve(num_tensors);
    variable_list grads;
    grads.reserve(num_tensors);
    for (int i = 0; i < num_tensors; i++) {
        PyObject *_tensor = PyTuple_GET_ITEM(tensors, i);
        THPUtils_assert(THPVariable_Check(_tensor), "element %d of tensors "
                "tuple is not a Tensor", i);
        auto& variable = ((THPVariable*)_tensor)->cdata;
        auto gradient_edge = torch::autograd::impl::gradient_edge(variable);
        THPUtils_assert(gradient_edge.function,
                "element %d of tensors does not require grad and does not have a grad_fn", i);
        roots.push_back(std::move(gradient_edge));

        PyObject *grad = PyTuple_GET_ITEM(grad_tensors, i);
        if (THPVariable_Check(grad)) {
            const Variable& grad_var = ((THPVariable*)grad)->cdata;
            if (grad_var.has_names()) {
                TORCH_WARN(
                        "Autograd was passed a named grad tensor with dims ", grad_var.names(),
                        ". Autograd does not yet support named tensor semantics, so all names ",
                        "will be ignored. In practice all computed gradients will still be correct "
                        "according to regular tensor semantics.");
            }
            grads.push_back(grad_var);
        } else {
            THPUtils_assert(grad == Py_None,
                    "element %d of gradients tuple is not a Tensor or None", i);
            THPUtils_assert(!variable.requires_grad(),
                    "element %d of gradients tuple is None, but the corresponding Tensor requires grad");
        }
    }

    // ...
}

下边继续看。这部分代码就是对inputs中的每一个元素都传入函数torch::autograd::impl::try_get_grad_accumulator中去处理。函数try_get_grad_accumulator被定义在文件 torch/csrc/autograd/variable.h 的第 113 行,具体实现则是在文件 torch/csrc/autograd/variable.cpp 的第111 行,这个等下再说,现在只需要知道返回的是个指向Node对象的指针。接下来就是,如果指针不是空指针,则执行output_edges.emplace_back(grad_fn, output_nr)

函数push_back()emplace_back()的区别是,push_back()函数向容器中加入一个临时对象(右值元素)时, 首先会调用构造函数生成这个对象,然后条用拷贝构造函数将这个对象放入容器中, 最后释放临时对象。但是emplace_back()函数向容器中中加入临时对象, 临时对象原地构造,没有赋值或移动的操作。详细内容参阅 cpp/container/vector/emplace_back

// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
    // ...

    std::vector<Edge> output_edges;
    if (inputs != nullptr) {
        int num_inputs = PyTuple_GET_SIZE(inputs);
        output_edges.reserve(num_inputs);
        for (int i = 0; i < num_inputs; ++i) {
            PyObject *input = PyTuple_GET_ITEM(inputs, i);
            THPUtils_assert(THPVariable_Check(input),
                    "all inputs have to be Tensors, but got %s", THPUtils_typename(input));
            THPVariable *input_var = (THPVariable*)input;
            const auto output_nr = input_var->cdata.output_nr();
            auto grad_fn = input_var->cdata.grad_fn();
            if (!grad_fn) {
                    grad_fn = torch::autograd::impl::try_get_grad_accumulator(input_var->cdata);
            }
            THPUtils_assert(input_var->cdata.requires_grad(),
                    "One of the differentiated Tensors does not require grad");
            if (!grad_fn) {
                output_edges.emplace_back();
            } else {
                output_edges.emplace_back(grad_fn, output_nr);
            }
        }
    }

    // ...
}

现在来看看传入output_edges的这两个参数都是什么类型。grad_fn是指向Node对象的std::shared_ptr指针,现在来看看另外一个参数output_nr。结构体THPVariable被定义在文件 torch/csrc/autograd/python_variable.h 中,代码如下所示。可以看到其中cdata变量的类型是torch::autograd::Variable。最终在 torch/csrc/autograd/VariableTypeManual.cpp 找到函数output_nr,其返回的是文件 torch/csrc/autograd/variable.h 中定义的结构体AutogradMeta中的成员变量uint32_t output_nr_;,这和文件 torch/csrc/autograd/edge.h 中定义的结构体Edge初始化的参数类型刚好吻合。

// python_variable.h

// Python object that backs torch.autograd.Variable
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct THPVariable {
    PyObject_HEAD
    // Payload
    torch::autograd::Variable cdata;
    // Hooks to be run on backwards pass (corresponds to Python attr
    // '_backwards_hooks', set by 'register_hook')
    PyObject* backward_hooks = nullptr;
};

// =======================================================

// VariableTypeManual.cpp

int64_t output_nr(const Tensor & self) {
    if (impl::get_autograd_meta(self)) {
        return impl::get_autograd_meta(self)->output_nr_;
    } else {
        return 0;
    }
}

// =======================================================

// variable.h

struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
    // ...
    // The "output number" of this variable; e.g., if this variable
    // was the second output of a function, then output_nr == 1.
    // We use this to make sure we can setup the backwards trace
    // correctly when this variable is passed to another function.
    uint32_t output_nr_;
    // ...
};

// =======================================================

// edge.h

/// Represents a particular input of a function.
struct Edge {
    Edge() noexcept : function(nullptr), input_nr(0) {}
    Edge(std::shared_ptr<Node> function_, uint32_t input_nr_) noexcept
        : function(std::move(function_)), input_nr(input_nr_) {}
    // ...
};

该函数最后这部分的代码如下所示。注意到THPUtils_assert(allow_unreachable ... );,其中allow_unreachable flag可以追溯到上边第二部分的源码,通过Variable._execution_engine.run_backward传入的是allow_unreachable=TruePyTuple_GET_SIZE的作用是获取传入参数的sizePyTuple_New的作用是创建一个新的tuple对象,传入的参数就是新的tuple对象的sizePyTuple_SET_ITEM的作用是将THPVariable_Wrap(outputs[i])传入到py_outputs.get()的位置i处。这里最关键的就是函数engine.execute,我们下边具体介绍。

// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
    // ...

    variable_list outputs;
    {
        pybind11::gil_scoped_release no_gil;
        outputs = engine.execute(roots, grads, keep_graph, create_graph, output_edges);
    }

    if (inputs != nullptr) {
        int num_inputs = PyTuple_GET_SIZE(inputs);
        THPObjectPtr py_outputs {PyTuple_New(num_inputs)};
        if (!py_outputs) return nullptr;
        for (int i = 0; i < num_inputs; i++) {
            THPUtils_assert(allow_unreachable || outputs[i].defined(), "One of the "
                                            "differentiated Tensors appears to not have been used "
                                            "in the graph. Set allow_unused=True if this is the "
                                            "desired behavior.");
            PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
        }
        return py_outputs.release();
    } else {
        Py_RETURN_NONE;
    }

    // ...
}

PyTuple_GET_SIZEPyTuple_New以及PyTuple_SET_ITEM这样的函数,用处就是可以在C++中操纵Python对象。关于tuple的类似的函数可以查阅 c-api/tuple 。这些其实也很好记,函数名带有PyTuple的就是PythonTuple对象,带有PyList的就是PythonList对象,带有PyType的就是PythonType对象。更多内容可以去看一下 docs.python.org/3.5/c-a

[6] try_get_grad_accumulator

现在回头来看一下函数try_get_grad_accumulator,定义在文件 torch/csrc/autograd/variable.h 的第 113 行,具体实现则是在文件 torch/csrc/autograd/variable.cpp 的第111 行。源码简化之后,如下所示。

// variable.h
// ...
namespace torch { namespace autograd {
struct Node; 
struct AutogradMeta;
struct DifferentiableViewMeta;
using Variable = at::Tensor;
namespace impl {
    // ...
    TORCH_API AutogradMeta* get_autograd_meta(const Variable&);
    // ...
    TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);
    // ...
}
// ...
struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
    // ...
    std::weak_ptr<Node> grad_accumulator_; 
    // ...
};

// =================================

// variable.cpp
// ...
namespace torch {
namespace autograd {
// ...
namespace impl {
    // ...
    std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) {
        if (get_autograd_meta(self)) {
            return get_autograd_meta(self)->grad_accumulator_.lock();
        } else {
            return nullptr;
        }
    }
    // ...
    AutogradMeta* get_autograd_meta(const Variable& self) {
        // NB: could return null
        TORCH_CHECK(self.defined(), "cannot call get_autograd_meta() on undefined tensor");
        return static_cast<AutogradMeta*>(self.unsafeGetTensorImpl()->autograd_meta());
    }
    // ...
}
// ...

所以函数try_get_grad_accumulator就是先通过函数get_autograd_meta返回一个AutogradMeta结构体,然后访问结构体中的成员变量grad_accumulator_,而grad_accumulator_是一个指向类型为Node对象的std::weak_ptr指针。lock()函数的作用是创建一个std::shared_ptr来管理对象,try_get_grad_accumulator函数的返回类型是std::shared_ptr<Node>

weak_ptr设计的目的是为配合shared_ptr而引入的一种智能指针,详见 en.cppreference.com/w/cen.cppreference.com/w/c

[7] engine.execute(roots, grads, keep_graph, create_graph, output_edges)

接着上边第 5 部分继续来看,最重要的variable_list outputs;的值是由函数engine.execute得到的。engine的定义如下,在文件 torch/csrc/autograd/python_engine.cpp 的第 26 行。

static torch::autograd::python::PythonEngine engine; 

torch::autograd::python::PythonEngine的定义在文件 torch/csrc/autograd/python_engine.h 中,代码如下所示。结构体PythonEngine继承自结构体Engine,而其中的方法execute也是重载的Engine::execute函数,所以我们要讨论的函数就变成了Engine::execute(roots, inputs, keep_graph, create_graph, outputs)

// python_engine.h

namespace torch { namespace autograd { namespace python {

struct PythonEngine : public Engine {
  void thread_init(int device) override;
  void thread_on_exception(
      std::shared_ptr<GraphTask>& graph_task,
      const std::shared_ptr<Node>& fn,
      std::exception& e) override;
  variable_list execute(
      const edge_list& roots,
      const variable_list& inputs,
      bool keep_graph,
      bool create_graph,
      const edge_list& outputs = {}) override;

  variable_list execute_with_graph_task(
      std::shared_ptr<GraphTask> graph_task,
      std::shared_ptr<Node> graph_root) override;
  std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() override;
};

}}} // namespace torch::autograd::python

// ================================================

// python_engine.cpp

variable_list PythonEngine::execute(
    const edge_list& roots,
    const variable_list& inputs,
    bool keep_graph,
    bool create_graph,
    const edge_list& outputs) {
  try {
    return Engine::execute(roots, inputs, keep_graph, create_graph, outputs);
  } catch (python_error& e) {
    e.restore();
    throw;
  }
}

下边来看下文件 torch/csrc/autograd/engine.htorch/csrc/autograd/engine.cpp

接下来的内容高能 !!!

层层嵌套的调用让人眼花缭乱 !!!

variable_list Engine::execute(const edge_list& roots, 
                              const variable_list& inputs, 
                              bool keep_graph, 
                              bool create_graph, 
                              const edge_list& outputs)
// 调用 ↓
variable_list Engine::execute_with_graph_task(std::shared_ptr<GraphTask> graph_task,
                                              std::shared_ptr<Node> graph_root)
// 调用 ↓
void Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task,
                         bool reentrant_thread)
// 调用 ↓
void Engine::evaluate_function(std::shared_ptr<GraphTask>& graph_task,
                               Node* func, 
                               InputBuffer& inputs)
// 调用 ↓ (这个函数不是 Engine 结构体中的方法)
variable_list call_function(std::shared_ptr<GraphTask>& graph_task,
                            Node* func,
                            InputBuffer& inputBuffer) {
    // ...
    auto& fn = *func;
    auto inputs = call_pre_hooks(fn, InputBuffer::variables(std::move(inputBuffer)));
    variable_list outputs = fn(std::move(inputs));
    // ...
    if(has_post_hooks){
        // NOLINTNEXTLINE(bugprone-use-after-move)
        return call_post_hooks(fn, std::move(outputs), inputs);
    }
    return outputs;
}

// =================================

static variable_list call_pre_hooks(Node& fn, 
                                    variable_list inputs) {
    for (const auto& hook : fn.pre_hooks()) {
        inputs = (*hook)(inputs);
    }
    return inputs;
}

static variable_list call_post_hooks(Node& fn, 
                                     variable_list outputs, 
                                     const variable_list& inputs) {
    for (const auto& hook : fn.post_hooks()) {
        outputs = (*hook)(outputs, inputs);
    }
    return outputs;
}

下边来整理下这些函数里遇到的结构体。首先是结构体Node,被定义在文件 torch/csrc/autograd/function.h 的第 87 行。关于Node,其表示一个操作,可以理解成Autograd Graph中的顶点vertice。结构体GraphTask被定义在文件 torch/csrc/autograd/engine.h 的第 38 行,其作用是GraphTask holds metadata needed for a single execution of backward()

函数fn.post_hooks()fn.pre_hooks()分别返回结构体成员变量post_hooks_pre_hooks_,二者类型分别为std::vector<std::unique_ptr<FunctionPostHook>>std::vector<std::unique_ptr<FunctionPreHook>>。这里又涉及到了一个结构体struct FunctionPreHook。关于指针unique_ptr,与shared_ptr不同,某个时刻只能有一个unique_ptr指向一个给定的对象;当unique_ptr被销毁时,它所指向的对象也被销毁,uniptr_ptr表达的是一种独占的思想。说回结构体FunctionPreHookFunctionPostHook,这两个结构体都被定义在文件 torch/csrc/autograd/function_hook.h 中。

#pragma once

#include <vector>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/Tensor.h>

// A hook that's called on gradients

namespace torch { namespace autograd {

using Variable = at::Tensor;
using variable_list = std::vector<Variable>;

struct TORCH_API FunctionPreHook {
    virtual ~FunctionPreHook();
    virtual variable_list operator()(const variable_list& grads) = 0;
};

struct TORCH_API FunctionPostHook {
    virtual ~FunctionPostHook();
    virtual variable_list operator()(
        const variable_list& outputs /* grad_inputs */,
        const variable_list& inputs /* grad_outputs */) = 0;
};

}} // namespace torch::autograd

本文就介绍到这里,如有错误之处,欢迎批评指正。


这代码,真的服了~

发布于 2019-12-16 15:48