This article is based on PyTorch 1.7.0 (https://github.com/pytorch/pytorch/tree/v1.7.0). If anything here is unclear or incorrect, please point it out in the comments.

The earlier post 动态计算图之反向传播 described how backpropagation is invoked; the call chain ultimately ends in Engine::execute().
The Engine class walks the already-built dynamic computation graph backwards and computes the gradient of each node, and its work starts from the execute() method.

That post did not introduce the Engine class itself, so this article traces the execute() call path and introduces Engine and the related classes.

import torch
a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)
c = torch.add(a, b)
d = torch.mul(a, c)
d.backward()
print(f"a grad:{a.grad} grad_fn:{a.grad_fn}")
print(f"b grad:{b.grad} grad_fn:{b.grad_fn}")
print(f"c grad:{c.grad} grad_fn:{c.grad_fn}")
print(f"d grad:{d.grad} grad_fn:{d.grad_fn}")
"""
a grad:4.0 grad_fn:None
b grad:1.0 grad_fn:None
c grad:None grad_fn:<AddBackward0 object at 0x7fdb27bcbf90>
d grad:None grad_fn:<MulBackward0 object at 0x7fdb27bcb210>
"""

(Figure: the forward dynamic computation graph built by the example above; original image dag_forward_1.svg)

1 execute()

1.1 Parameters

[torch/csrc/autograd/engine.cpp]
auto Engine::execute(const edge_list& roots,
                     const variable_list& inputs,
                     bool keep_graph,
                     bool create_graph,
                     bool accumulate_grad,
                     const edge_list& outputs) -> variable_list

1.1.1 roots

roots is an edge_list holding the edges to the root nodes (starting points) of the backward graph.

[torch/csrc/autograd/function.h]
using edge_list = std::vector<Edge>;
  • Edge

An Edge is a directed edge between nodes, represented as [function, input_nr].
function is the node the edge points to, and input_nr identifies which input slot of that node the edge feeds. Viewed from the forward pass, input_nr corresponds to the index of the forward output: Add produces a single output, so the edge into AddBackward has input_nr 0.

[torch/csrc/autograd/edge.h]
struct Edge {
  Edge() noexcept : function(nullptr), input_nr(0) {}

  Edge(std::shared_ptr<Node> function_, uint32_t input_nr_) noexcept
      : function(std::move(function_)), input_nr(input_nr_) {}

  /// The function this `Edge` points to.
  std::shared_ptr<Node> function;

  /// The identifier of a particular input to the function.
  uint32_t input_nr;
};
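Continuing the example at the top, the (function, input_nr) pairs of these edges can be inspected from Python through grad_fn.next_functions (object addresses will differ across runs):

print(d.grad_fn.next_functions)
# ((<AccumulateGrad object at 0x...>, 0), (<AddBackward0 object at 0x...>, 0))
print(c.grad_fn.next_functions)
# ((<AccumulateGrad object at 0x...>, 0), (<AccumulateGrad object at 0x...>, 0))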

1.1.2 inputs

using variable_list = std::vector<Variable>;

inputs holds the Variables fed to the root node(s), i.e. the initial gradients of the backward pass; for the scalar d.backward() above this is a single tensor containing 1.0.

1.1.3 keep_graph

Whether to keep the dynamic computation graph after this backward pass (exposed in Python as the retain_graph argument; see the example after section 1.1.4).

1.1.4 create_graph

Whether to build a computation graph for the backward computation itself, which makes higher-order derivatives possible (exposed in Python as the create_graph argument).
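A quick Python-level illustration of both flags; retain_graph and create_graph in the Python API are what eventually arrive at execute() as keep_graph and create_graph:

import torch

a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)
d = a * (a + b)

# keep_graph: keep the graph alive so backward() can run on it again
d.backward(retain_graph=True)
d.backward()   # would fail if the first call had not retained the graph

# create_graph: build a graph for the backward computation itself,
# which makes higher-order derivatives possible
(grad_a,) = torch.autograd.grad(a * (a + b), a, create_graph=True)   # dd/da = 2a + b
(grad2_a,) = torch.autograd.grad(grad_a, a)                          # d²d/da² = 2
print(grad2_a)   # tensor(2.)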

1.1.5 accumulate_grad

Whether the computed gradients should be accumulated into the .grad fields of the leaf tensors (true for tensor.backward()) or captured and returned to the caller (false for torch.autograd.grad()).

1.1.6 outputs

The edges whose gradients should be captured and returned as the result of execute(); this is how torch.autograd.grad() specifies which inputs it wants gradients for. For tensor.backward() this list is empty.

1.2 Initializing local_ready_queue

[torch/csrc/autograd/engine.cpp]
auto Engine::execute(const edge_list& roots,
                     const variable_list& inputs,
                     bool keep_graph,
                     bool create_graph,
                     bool accumulate_grad,
                     const edge_list& outputs) -> variable_list {
  ...
  // A fresh first time Engine::execute call should start on the CPU device, initialize
  // a new thread local ready queue on CPU or reuse the existing one (if there is one
  // allocated already, i.e. consecutive backward calls, re-entrant backward calls),
  // then memoize the local_ready_queue in GraphTask
  init_local_ready_queue();
  ...
}

The source of init_local_ready_queue() is as follows:

[torch/csrc/autograd/engine.cpp]
...
static thread_local std::shared_ptr<ReadyQueue> local_ready_queue = nullptr;
...
void Engine::init_local_ready_queue(std::shared_ptr<ReadyQueue> ready_queue) {
  if (ready_queue) {
    // if ready_queue provided in the caller, use the caller's ready_queue to initialize local_ready_queue
    local_ready_queue = std::move(ready_queue);
  } else if (!local_ready_queue){
    // otherwise if local_ready_queue not allocated, allocate a new ready_queue
    local_ready_queue = std::make_shared<ReadyQueue>();
  }
}
  • ReadyQueue

ReadyQueue keeps the pending NodeTasks in a priority queue, protected by a mutex and paired with a condition variable so that worker threads can block until work arrives.

[torch/csrc/autograd/engine.h]
struct ReadyQueue {
 private:
  // Returns true when t2 should be (weakly) BEFORE t1 in the queue.
  // Shutdown tasks are first and then empty NodeTask are next.
  struct CompareNodeTaskTime {
    bool operator()(NodeTask const & t1, NodeTask const & t2) {
      ...
    }
  };

  // To notify threads waiting on the ReadyQueue of available tasks on the heap_
  std::condition_variable not_empty_;
  // To protect read and writes to heap_
  mutable std::mutex mutex_;
  std::priority_queue<NodeTask, std::vector<NodeTask>, CompareNodeTaskTime> heap_;
  ...
};
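As a rough, non-PyTorch sketch of the same idea, the push/pop behaviour of ReadyQueue can be mimicked in Python with a lock, a condition variable and a heap (class and method names below are made up for illustration):

import heapq
import threading

class SimpleReadyQueue:
    """Illustrative ready queue: push wakes a waiter, pop blocks while empty."""
    def __init__(self):
        self._heap = []                      # entries are (priority, seq, task)
        self._seq = 0
        self._mutex = threading.Lock()
        self._not_empty = threading.Condition(self._mutex)

    def push(self, task, priority=0):
        with self._not_empty:
            heapq.heappush(self._heap, (priority, self._seq, task))
            self._seq += 1
            self._not_empty.notify()         # wake one thread blocked in pop()

    def pop(self):
        with self._not_empty:
            while not self._heap:
                self._not_empty.wait()       # corresponds to waiting on not_empty_ in ReadyQueue
            return heapq.heappop(self._heap)[-1]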

1.3 Creating the GraphTask instance

auto graph_task = std::make_shared<GraphTask>(
      /* keep_graph */ keep_graph,
      /* create_graph */ create_graph,
      /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
      /* cpu_ready_queue */ local_ready_queue);
  • GraphTask

GraphTask holds the metadata needed for a single execution of backward().

[torch/csrc/autograd/engine.h]
// GraphTask holds metadata needed for a single execution of backward()
struct GraphTask: std::enable_shared_from_this<GraphTask> {
  std::atomic<uint64_t> outstanding_tasks_{0};
  // Indicates if an error occurred while executing any task.  When this is
  // true, it signals all threads to stop executing.
  std::atomic_bool has_error_{false};
  std::atomic_bool future_completed_{false};
  // It is safe to read grad_mode_ and keep_graph_ without synchronization
  bool keep_graph_;
  bool grad_mode_;

  // To protect reads/writes to not_ready_, dependencies_, captured_vars_,
  // has_error_, future_result_, cpu_ready_queue_, and leaf_streams.
  std::mutex mutex_;
  std::unordered_map<Node*, InputBuffer> not_ready_;
  std::unordered_map<Node*, int> dependencies_;

  at::ThreadLocalState thread_locals_ =
      at::ThreadLocalState(/* keep_grad_mode */ false);

  std::unordered_set<c10::Stream> leaf_streams;

  void init_to_execute(Node& graph_root, const edge_list& outputs, bool accumulate_grad);

  // The value of worker_device in the thread that created this task.
  // See Note [Reentrant backwards]
  // Safe to read owner_ and reentrant_depth_ without synchronizaton
  int owner_;
  // The number of parent graph tasks for this graph task
  const int reentrant_depth_;

  // check if the GraphTask is completed or not
  bool completed();
  // mark the graph task as completed and trigger post processing
  void mark_as_completed_and_run_post_processing();

  // CPU threads are dedicated to processing CPU work for the backward they invoked.
  // So any given graph task maintains its own cpu_ready_queue_ where you should send
  // work for it to be done. We memoize the cpu_ready_queue_ per GraphTask so that
  // we know which ready queue we should push to if we are on device thread (i.e. GPU)
  // and but next NodeTask should be run on CPU.
  std::shared_ptr<ReadyQueue> cpu_ready_queue_;

  // Future representing the completion of the graph task. Notified when all
  // tasks are done.
  std::shared_ptr<at::ivalue::Future> future_result_;
  ...
};
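For the rest of this walkthrough the fields that matter most are dependencies_, not_ready_, outstanding_tasks_ and cpu_ready_queue_. A heavily simplified Python analogue of this bookkeeping (hypothetical names, illustration only):

from dataclasses import dataclass, field

@dataclass
class ToyGraphTask:
    keep_graph: bool
    dependencies: dict = field(default_factory=dict)   # Node -> remaining in-degree
    not_ready: dict = field(default_factory=dict)      # Node -> partially accumulated input buffer
    outstanding_tasks: int = 0                          # NodeTasks queued or currently running

    def completed(self):
        # simplified version of GraphTask::completed(): all queued work has drained
        return self.outstanding_tasks == 0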

1.4 Creating the GraphRoot instance

[torch/csrc/autograd/engine.cpp]
auto Engine::execute(const edge_list& roots,
                     const variable_list& inputs,
                     bool keep_graph,
                     bool create_graph,
                     bool accumulate_grad,
                     const edge_list& outputs) -> variable_list {
  ...
  // If we receive a single root, skip creating extra root node
  bool skip_dummy_node = roots.size() == 1;
  auto graph_root = skip_dummy_node ?
  roots.at(0).function :
  std::make_shared<GraphRoot>(roots, inputs);
  ...
}

If there is only one root, that node itself is used as the starting point of the graph; otherwise a GraphRoot instance is created as a dummy starting node that fans out to all roots. In the opening example, d.backward() has a single root, so graph_root is simply d's MulBackward0.
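A multi-root case arises, for instance, when torch.autograd.backward() is given several tensors; then skip_dummy_node is false and a GraphRoot is inserted in front of them:

import torch

a = torch.tensor(1.0, requires_grad=True)
loss1 = a * 2
loss2 = a * 3
# two roots -> the engine creates a GraphRoot that fans out to both of them
torch.autograd.backward([loss1, loss2],
                        grad_tensors=[torch.tensor(1.0), torch.tensor(1.0)])
print(a.grad)   # tensor(5.)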

1.5 Computing dependencies

[torch/csrc/autograd/engine.cpp]
auto Engine::execute(const edge_list& roots,
                     const variable_list& inputs,
                     bool keep_graph,
                     bool create_graph,
                     bool accumulate_grad,
                     const edge_list& outputs) -> variable_list {
  ...
  // Now compute the dependencies for all executable functions and queue the root
  compute_dependencies(graph_root.get(), *graph_task);
  ...
}

compute_dependencies() counts, for every Node, how many edges point to it (its in-degree in the backward graph). A node only becomes ready to run once this count drops to zero; a worked example follows the source below.

[torch/csrc/autograd/engine.cpp]
/* Computes the number of dependencies for each function which requires grad */
auto Engine::compute_dependencies(Node* root, GraphTask& task) -> void {
  // Just to make sure that they will never be added to the queue again
  std::unordered_set<Node*> seen;
  std::vector<Node*> queue { root };

  // Queue contains all nodes that will start propagating gradients.
  // We no longer have to expand functions that don't require grad.
  auto& dependencies = task.dependencies_;
  while (!queue.empty()) {
    auto fn = queue.back(); queue.pop_back();
    for (const auto& edge : fn->next_edges()) {
      if (auto next_ptr = edge.function.get()) {
        dependencies[next_ptr] += 1;
        const bool was_inserted = seen.insert(next_ptr).second;
        if (was_inserted) queue.push_back(next_ptr);
      }
    }
  }
}
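For the opening example, the backward graph is roughly MulBackward0 -> {AccumulateGrad(a), AddBackward0} and AddBackward0 -> {AccumulateGrad(a), AccumulateGrad(b)}. A minimal Python sketch (not PyTorch code) of the same counting logic:

from collections import defaultdict

# adjacency list standing in for Node::next_edges()
graph = {
    "MulBackward0": ["AccumulateGrad(a)", "AddBackward0"],
    "AddBackward0": ["AccumulateGrad(a)", "AccumulateGrad(b)"],
    "AccumulateGrad(a)": [],
    "AccumulateGrad(b)": [],
}

def compute_dependencies(root):
    """Same idea as Engine::compute_dependencies: DFS that counts each node's in-degree."""
    dependencies = defaultdict(int)
    seen, stack = set(), [root]
    while stack:
        fn = stack.pop()
        for nxt in graph[fn]:
            dependencies[nxt] += 1
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return dict(dependencies)

print(compute_dependencies("MulBackward0"))
# {'AccumulateGrad(a)': 2, 'AddBackward0': 1, 'AccumulateGrad(b)': 1}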

1.6 Preparing to run the graph

[torch/csrc/autograd/engine.cpp]
auto Engine::execute(const edge_list& roots,
                     const variable_list& inputs,
                     bool keep_graph,
                     bool create_graph,
                     bool accumulate_grad,
                     const edge_list& outputs) -> variable_list {
  ...
  if (skip_dummy_node) {
    InputBuffer input_buffer(roots.at(0).function->num_inputs());
    auto input = inputs.at(0);

    const auto input_stream = InputMetadata(input).stream();
    const auto opt_next_stream = roots.at(0).function->stream(c10::DeviceType::CUDA);
    input_buffer.add(roots.at(0).input_nr,
                      std::move(input),
                      input_stream,
                      opt_next_stream);

    ...
  } else {
    ...
  }
  ...
}

The InputBuffer class (the paragraph below is the header comment from input_buffer.h):

The InputBuffer class accumulates a list of Variables for use by a function. It implements logic to avoid modifying the passed values in-place (adding an input twice will accumulate the result). This behaviour is needed and used only in backward graphs.

[torch/csrc/autograd/input_buffer.h]
struct InputBuffer {
  ...
  explicit InputBuffer(size_t size)
    : buffer(size) {}
  explicit InputBuffer(variable_list&& inputs): buffer(std::move(inputs)) {};
  // Accumulates the variable at a specified index.
  // The optional CUDA streams determine which stream the accumulation
  // is run on and how the addition is synchronized.
  void add(size_t pos,
           Variable&& var,
           const c10::optional<c10::Stream>& opt_producer_stream,
           const c10::optional<c10::Stream>& opt_consumer_stream);
  ...
  // Returns the inputs as a list of variables. Destroys given InputBuffer.
  static std::vector<Variable> variables(InputBuffer&& g);
private:
  std::vector<Variable> buffer;
};
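A toy Python sketch of the accumulation behaviour (plain floats instead of Variables, hypothetical class name). The two add() calls mirror how the buffer in front of AccumulateGrad(a) in the opening example receives 3.0 from MulBackward0 and another 1.0 along the AddBackward0 path, which is where a.grad = 4.0 comes from:

class SimpleInputBuffer:
    """Toy InputBuffer: adding to the same slot twice accumulates instead of overwriting."""
    def __init__(self, size):
        self.buffer = [None] * size

    def add(self, pos, var):
        if self.buffer[pos] is None:
            self.buffer[pos] = var
        else:
            self.buffer[pos] = self.buffer[pos] + var   # accumulate

buf = SimpleInputBuffer(1)
buf.add(0, 3.0)   # gradient reaching AccumulateGrad(a) directly from MulBackward0 (= c = 3)
buf.add(0, 1.0)   # gradient reaching it through AddBackward0 (= a * 1 = 1)
print(buf.buffer) # [4.0] -> matches a.grad in the example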

1.7 Executing the graph task

[torch/csrc/autograd/engine.cpp]
auto Engine::execute(const edge_list& roots,
                     const variable_list& inputs,
                     bool keep_graph,
                     bool create_graph,
                     bool accumulate_grad,
                     const edge_list& outputs) -> variable_list {
  ...
  if (skip_dummy_node) {
    ...
    execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
  } else {
    ...
  }
  ...
}

Because a backward() started from Python runs on a PythonEngine (a subclass of Engine), the virtual call to execute_with_graph_task() dispatches to PythonEngine::execute_with_graph_task():

[torch/csrc/autograd/python_engine.cpp]
std::shared_ptr<at::ivalue::Future> PythonEngine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {
  try {
    return Engine::execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
  } catch (python_error& e) {
    pybind11::gil_scoped_acquire gil;
    if (!PyErr_Occurred()) {
      // Set the error indicator only if it is not set already.
      e.restore();
    }
    throw;
  }
}

PythonEngine::execute_with_graph_task() is a thin wrapper that adds Python error handling and then calls Engine::execute_with_graph_task().
The rest of this section walks through Engine::execute_with_graph_task().

1.7.1 Initializing the thread pool

If you are unfamiliar with threads or thread pools, the thread-pool chapter of C++ Concurrency in Action is a good reference.

[torch/csrc/autograd/engine.cpp]
std::shared_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {
  initialize_device_threads_pool();
  ...
}

initialize_device_threads_pool() uses std::call_once to run start_device_threads(), which sets up the device worker threads:

[torch/csrc/autograd/engine.cpp]
void Engine::initialize_device_threads_pool() {
  track_bad_autograd_forks();
  TORCH_CHECK(!in_bad_autograd_fork,
              "Unable to handle autograd's threading in combination with fork-based multiprocessing. "
              "See https://github.com/pytorch/pytorch/wiki/Autograd-and-Fork");
  std::call_once(start_device_threads_flag_, &Engine::start_device_threads, this);
}
  • start_device_threads()

[torch/csrc/autograd/engine.cpp]
auto Engine::start_device_threads() -> void {
  // See Note [Allocating GPUs to autograd threads]
  c10::DeviceIndex num_devices = 0;
  for (const auto& impl_atomic : c10::impl::device_guard_impl_registry) {
    auto* impl = impl_atomic.load();
    if (impl) {
      num_devices = std::max(num_devices, impl->deviceCount());
    }
  }

  // allocate one thread for every GPU device (but colocate GPUs of different
  // types), and pre-allocate the device_ready_queues_ to ensure safe reading on it.
  device_ready_queues_ = std::vector<std::shared_ptr<ReadyQueue>>(num_devices);
  for (auto& queue : device_ready_queues_) {
    queue.reset(new ReadyQueue());
  }

  thread_pool_shared_ = std::make_shared<ThreadPoolShared>();

  for (int i = 0; i < num_devices; ++i) {
    std::thread t(&Engine::thread_init, this, i, device_ready_queues_[i], true);
    t.detach();
  }
  // Wait for the threads to start
  {
    std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
    while (non_reentrant_device_thread_count_.load() != num_devices) {
      non_reentrant_device_thread_condvar_.wait(lk);
    }
  }
}
    

The loop over num_devices creates one detached std::thread per device; the function each worker thread runs is thread_init():

[torch/csrc/autograd/engine.cpp]
void Engine::thread_init(int device, const std::shared_ptr<ReadyQueue>& ready_queue, bool should_increment) {
  if (should_increment) {
    increment_non_reentrant_thread_count();
  }

  at::init_num_threads();
  set_device(device);

  // initialize each device thread's thread local ready queue with the ready queue
  // that is created before the thread initialization
  init_local_ready_queue(ready_queue);

  std::shared_ptr<GraphTask> graph_task = nullptr;
  thread_main(graph_task);
  if (should_increment) {
    // Decrement the count during shutdown if we incremented earlier.
    decrement_non_reentrant_thread_count();
  }
}
    

1.7.2 Lock mutex for GraphTask

Back in execute_with_graph_task(), the GraphTask's mutex_ is locked; as noted in the GraphTask definition above, it protects fields such as not_ready_, dependencies_ and cpu_ready_queue_:

    [torch/csrc/autograd/engine.cpp]
    std::shared_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
      const std::shared_ptr<GraphTask>& graph_task,
      std::shared_ptr<Node> graph_root,
      InputBuffer&& input_buffer) {
    ...
    std::unique_lock<std::mutex> lock(graph_task->mutex_);
    ...
    }
    

1.7.3 Queues

The root NodeTask is then pushed onto the ready queue that matches the device of the input buffer:

    [torch/csrc/autograd/engine.cpp]
    std::shared_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
      const std::shared_ptr<GraphTask>& graph_task,
      std::shared_ptr<Node> graph_root,
      InputBuffer&& input_buffer) {
    ..
    auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());
    queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
    ...
    }
    
  • ready_queue() returns the queue to push to

[torch/csrc/autograd/engine.cpp]
// CPU ready queue is per GraphTask, but CUDA device ready queues are shared across all graph tasks
auto Engine::ready_queue(std::shared_ptr<ReadyQueue> cpu_ready_queue, at::Device device) -> std::shared_ptr<ReadyQueue>{
  if (device.type() == at::kCPU) {
    // return the cpu ready queue passed in
    TORCH_INTERNAL_ASSERT(cpu_ready_queue);
    return cpu_ready_queue;
  } else {
    // See Note [Allocating GPUs to autograd threads]
    return device_ready_queues_.at(device.index());
  }
}
    
  • NodeTask

    [torch/csrc/autograd/engine.h]
    struct NodeTask {
    std::weak_ptr<GraphTask> base_;
    std::shared_ptr<Node> fn_;
    // This buffer serves as an implicit "addition" node for all of the
    // gradients flowing here.  Once all the dependencies are finished, we
    // use the contents of this buffer to run the function.
    InputBuffer inputs_;
    // When worker receives a task with isShutdownTask = true, it will immediately
    // exit. The engine sends a shutdown task to every queue upon its destruction.
    bool isShutdownTask_;
    
    int getReentrantDepth() const;
    
    NodeTask(
        std::weak_ptr<GraphTask> base,
        std::shared_ptr<Node> fn,
        InputBuffer inputs,
        bool isShutdownTask = false)
        : base_(base),
          fn_(std::move(fn)),
          inputs_(std::move(inputs)),
          isShutdownTask_(isShutdownTask) {}
    };
    

1.7.4 Running the graph task

    [torch/csrc/autograd/engine.cpp]
    std::shared_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
      const std::shared_ptr<GraphTask>& graph_task,
      std::shared_ptr<Node> graph_root,
      InputBuffer&& input_buffer) {
    ..
    thread_main(graph_task);
    TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed());
    ...
    }
    

The call to thread_main(graph_task) runs the backward pass on the calling thread; by the time it returns the graph task is expected to be finished, which the TORCH_INTERNAL_ASSERT on the following line checks.

  • thread_main()

    [torch/csrc/autograd/engine.cpp]
    auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
    // When graph_task is nullptr, this is a long running thread that processes
    // tasks (ex: device threads). When graph_task is non-null (ex: reentrant
    // backwards, user thread), this function is expected to exit once that
    // graph_task complete.
    
    // local_ready_queue should already been initialized when we get into thread_main
    TORCH_INTERNAL_ASSERT(local_ready_queue != nullptr);
    while (graph_task == nullptr || !graph_task->future_result_->completed()) {
      // local_graph_task represents the graph_task we retrieve from the queue.
      // The outer graph_task represents the overall graph_task we need to execute
      // for reentrant execution.
      std::shared_ptr<GraphTask> local_graph_task;
      {
        // Scope this block of execution since NodeTask is not needed after this
        // block and can be deallocated (release any references to grad tensors
        // as part of inputs_).
        NodeTask task = local_ready_queue->pop();
        // This will only work if the worker is running a non backward task
        // TODO Needs to be fixed this to work in all cases
        if (task.isShutdownTask_) {
          C10_LOG_API_USAGE_ONCE("torch.autograd.thread_shutdown");
          break;
        }
    
        if (!(local_graph_task = task.base_.lock())) {
          // GraphTask for function is no longer valid, skipping further
          // execution.
          continue;
        }
    
        if (task.fn_ && !local_graph_task->has_error_.load()) {
          AutoGradMode grad_mode(local_graph_task->grad_mode_);
          try {
            // The guard sets the thread_local current_graph_task on construction
            // and restores it on exit. The current_graph_task variable helps
            // queue_callback() to find the target GraphTask to append final
            // callbacks.
            GraphTaskGuard guard(local_graph_task);
            NodeGuard ndguard(task.fn_);
            evaluate_function(local_graph_task, task.fn_.get(), task.inputs_, local_graph_task->cpu_ready_queue_);
          } catch (std::exception& e) {
            thread_on_exception(local_graph_task, task.fn_, e);
          }
        }
      }
    
      // Decrement the outstanding tasks.
      --local_graph_task->outstanding_tasks_;
    
      // Check if we've completed execution.
      if (local_graph_task->completed()) {
        local_graph_task->mark_as_completed_and_run_post_processing();
    
        auto base_owner = local_graph_task->owner_;
        // The current worker thread finish the graph_task, but the owning thread
        // of the graph_task might be sleeping on pop() if it does not have work.
        // So we need to send a dummy function task to the owning thread just to
        // ensure that it's not sleeping, so that we can exit the thread_main.
        // If it has work, it might see that graph_task->outstanding_tasks_ == 0
        // before it gets to the task, but it's a no-op anyway.
        //
        // NB: This is not necessary if the current thread is the owning thread.
        if (worker_device != base_owner) {
          // Synchronize outstanding_tasks_ with queue mutex
          std::atomic_thread_fence(std::memory_order_release);
          ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
              ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
        }
      }
    }
    }
    

The while loop in thread_main() drives the graph computation: each iteration pops a NodeTask from the local ready queue and calls evaluate_function() to process one Node, which in turn may push new NodeTasks for downstream nodes.

    [torch/csrc/autograd/engine.cpp]
    void Engine::evaluate_function(
      std::shared_ptr<GraphTask>& graph_task,
      Node* func,
      InputBuffer& inputs,
      const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
    ...
    // Switches to a function's CUDA stream (if applicable) before calling it
    const auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
    c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
    
    auto outputs = call_function(graph_task, func, inputs);
    
    auto& fn = *func;
    if (!graph_task->keep_graph_) {
      fn.release_variables();
    }
    
    int num_outputs = outputs.size();
    if (num_outputs == 0) { // Note: doesn't acquire the mutex
      // Records leaf stream (if applicable)
      // See note "Streaming backwards"
      if (opt_parent_stream) {
        std::lock_guard<std::mutex> lock(graph_task->mutex_);
        graph_task->leaf_streams.emplace(*opt_parent_stream);
      }
      return;
    }
    
    if (AnomalyMode::is_enabled()) {
      AutoGradMode grad_mode(false);
      for (int i = 0; i < num_outputs; ++i) {
        auto& output = outputs[i];
        at::OptionalDeviceGuard guard(device_of(output));
        if (output.defined() && isnan(output).any().item<uint8_t>()) {
          std::stringstream ss;
          ss << "Function '" << fn.name() << "' returned nan values in its " << i << "th output.";
          throw std::runtime_error(ss.str());
        }
      }
    }
    
    // Lock mutex for the accesses to GraphTask dependencies_, not_ready_ and cpu_ready_queue_ below
    std::lock_guard<std::mutex> lock(graph_task->mutex_);
    for (int i = 0; i < num_outputs; ++i) {
      auto& output = outputs[i];
      const auto& next = fn.next_edge(i);
    
      if (!next.is_valid()) continue;
    
      // Check if the next function is ready to be computed
      bool is_ready = false;
      auto& dependencies = graph_task->dependencies_;
      auto it = dependencies.find(next.function.get());
    
      if (it == dependencies.end()) {
        auto name = next.function->name();
        throw std::runtime_error(std::string("dependency not found for ") + name);
      } else if (--it->second == 0) {
        dependencies.erase(it);
        is_ready = true;
      }
    
      auto& not_ready = graph_task->not_ready_;
      auto not_ready_it = not_ready.find(next.function.get());
      if (not_ready_it == not_ready.end()) {
        // Skip functions that aren't supposed to be executed
        if (!exec_info_.empty()) {
          auto it = exec_info_.find(next.function.get());
          if (it == exec_info_.end() || !it->second.should_execute()) {
            continue;
          }
        }
        // No buffers have been allocated for the function
        InputBuffer input_buffer(next.function->num_inputs());
    
        // Accumulates into buffer
        const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
        input_buffer.add(next.input_nr,
                         std::move(output),
                         opt_parent_stream,
                         opt_next_stream);
    
        if (is_ready) {
          auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
          queue->push(
              NodeTask(graph_task, next.function, std::move(input_buffer)));
        } else {
          not_ready.emplace(next.function.get(), std::move(input_buffer));
        }
      } else {
        // The function already has a buffer
        auto &input_buffer = not_ready_it->second;
    
        // Accumulates into buffer
        const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
        input_buffer.add(next.input_nr,
                         std::move(output),
                         opt_parent_stream,
                         opt_next_stream);
        if (is_ready) {
          auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
          queue->push(
              NodeTask(graph_task, next.function, std::move(input_buffer)));
          not_ready.erase(not_ready_it);
        }
      }
    }
    }
    

call_function() performs the actual computation of a Node (pre-hooks, the backward function itself, then post-hooks) and returns its outputs.

  • call_function()

    static variable_list call_function(
      std::shared_ptr<GraphTask>& graph_task,
      Node* func,
      InputBuffer& inputBuffer) {
    bool prev_checkpoint_valid_state = checkpoint_valid;
    checkpoint_valid =
        graph_task->can_checkpoint() && prev_checkpoint_valid_state;
    auto& fn = *func;
    auto inputs =
        call_pre_hooks(fn, InputBuffer::variables(std::move(inputBuffer)));
    
    if (!graph_task->keep_graph_) {
      fn.will_release_variables();
    }
    
    const auto has_post_hooks = !fn.post_hooks().empty();
    variable_list outputs;
    
    {
      at::ThreadLocalStateGuard guard(graph_task->thread_locals_);
      if (has_post_hooks) {
        // In functions/accumulate_grad.cpp, there is some logic to check the
        // conditions under which the incoming gradient can be stolen directly
        // (which elides a deep copy) instead of cloned. One of these conditions
        // is that the incoming gradient's refcount must be 1 (nothing else is
        // referencing the same data).  Stashing inputs_copy here bumps the
        // refcount, so if post hooks are employed, it's actually still ok for
        // accumulate_grad.cpp to steal the gradient if the refcount is 2.
        //
        // "new_grad.use_count() <= 1 + !post_hooks().empty()" in
        // accumulate_grad.cpp accounts for this, but also creates a silent
        // dependency between engine.cpp (ie, this particular engine
        // implementation) and accumulate_grad.cpp.
        //
        // If you change the logic here, make sure it's compatible with
        // accumulate_grad.cpp.
        auto inputs_copy = inputs;
        outputs = fn(std::move(inputs_copy));
      } else {
        outputs = fn(std::move(inputs));
      }
    }
    
    validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) {
      std::ostringstream ss;
      ss << "Function "  << fn.name() << " returned an " << msg;
      return ss.str();
    });
    checkpoint_valid = prev_checkpoint_valid_state;
    
    if(has_post_hooks){
      // NOLINTNEXTLINE(bugprone-use-after-move)
      return call_post_hooks(fn, std::move(outputs), inputs);
    }
    return outputs;
    }
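Putting the pieces together, the scheduling described above can be condensed into a deliberately simplified, single-threaded Python sketch (not PyTorch code): compute in-degrees, push the root task, then repeatedly pop a task, run its backward function, route the produced gradients along the next edges, and enqueue a node once its dependency count reaches zero.

from collections import defaultdict, deque

# Toy backward graph for the opening example (a=1.0, b=2.0, c=a+b=3.0, d=a*c):
# each function maps its incoming gradient to a list of (next_node, outgoing gradient).
def mul_backward(grad_d):            # d = a * c
    return [("acc_a", grad_d * 3.0), ("add_backward", grad_d * 1.0)]   # dd/da = c, dd/dc = a

def add_backward(grad_c):            # c = a + b
    return [("acc_a", grad_c), ("acc_b", grad_c)]

grads = {}
def acc(name):                        # stand-in for AccumulateGrad
    def fn(g):
        grads[name] = grads.get(name, 0.0) + g
        return []
    return fn

nodes = {"mul_backward": mul_backward, "add_backward": add_backward,
         "acc_a": acc("a"), "acc_b": acc("b")}
edges = {"mul_backward": ["acc_a", "add_backward"],
         "add_backward": ["acc_a", "acc_b"], "acc_a": [], "acc_b": []}

# compute_dependencies: in-degree of every node reachable from the root
deps = defaultdict(int)
stack, seen = ["mul_backward"], set()
while stack:
    n = stack.pop()
    for nxt in edges[n]:
        deps[nxt] += 1
        if nxt not in seen:
            seen.add(nxt)
            stack.append(nxt)

# the scheduling loop: a single-threaded stand-in for thread_main()/evaluate_function()
buffers = defaultdict(float)             # stand-in for InputBuffer / not_ready_
ready = deque([("mul_backward", 1.0)])   # root task with the initial gradient 1.0
while ready:
    name, grad_in = ready.popleft()
    for nxt, grad_out in nodes[name](grad_in):
        buffers[nxt] += grad_out
        deps[nxt] -= 1
        if deps[nxt] == 0:
            ready.append((nxt, buffers.pop(nxt)))

print(grads)   # {'a': 4.0, 'b': 1.0} -- matches a.grad and b.grad at the top of the post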