/*!
 *
 * \file src/relay/pass/vacc/forward_graph.h
 *
 * \brief An indexed data flow graph in forward direction, together with the
 *  post-dominator tree that is computed from it.
 */
#ifndef TVM_RELAY_PASS_VACC_FORWARD_GRAPH_H_
#define TVM_RELAY_PASS_VACC_FORWARD_GRAPH_H_

#include <tvm/expr_operator.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <sstream>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../../../common/arena.h"
#include "../pattern_util.h"

namespace tvm {
namespace relay {

using common::LinkedList;
using common::LinkNode;

/*!
 * \brief Indexed data flow graph in forward direction.
 *  This is a temporary data structure used for operator fusion analysis.
 *
 *  This data structure only captures the dataflow fragment and
 *  could ignore blocks like let by simply ordering each dataflow block
 *  and mark the output node as extern_ref;
 */
/*
 * Notes:
 * The edges of this graph point in the opposite direction to the references
 * of the TVM expression graph: an expression refers to its inputs, whereas a
 * graph Node lists (in `outputs`) the nodes that consume it.
 *
 * Traversal starts at the FunctionNode. While visiting it, the tvm::Nodes it
 * refers to are registered in graph_.node_map through Update(), so that when
 * those tvm::Nodes are visited recursively, their matching
 * SimpleIndexedForwardGraph::Node can be looked up in graph_.node_map.
 *
 * In general, visiting a tvm::Node (CallNode as the example) works like this:
 *  1. Look up the SimpleIndexedForwardGraph::Node matching the current
 *     tvm::Node in graph_.node_map.
 *  2. Register every tvm::Node referenced by the current one (the callee `op`
 *     and each expression in `args`) in graph_.node_map through Update(); for
 *     the arguments, `parent` is the Node obtained in step 1.
 *  3. Call ExprVisitor::VisitExpr_(call) to recurse into the referenced
 *     expressions.
 *  4. Once everything the current tvm::Node depends on has been visited, call
 *     AddNode(call) to append its Node to graph_.post_dfs_order.
 *
 * Consequently, in graph_.post_dfs_order the leaves of the original graph come
 * first and the root (output) of the original graph comes last.
 */
class SimpleIndexedForwardGraph {
 public:
  struct Node;
  /*!
   * The forward edge in the dataflow graph.
   */
  struct Edge {
    /*! \brief The corresponding node */
    Node* node{nullptr};
    /*! \brief The respective pattern of this op */
    OpPatternKind pattern{kOpaque};
  };
  /*! \brief A node in the graph. */
  struct Node {
    /*! \brief weak reference to the corresponding edge. */
    const tvm::Node* ref{nullptr};
    /*! \brief The index of the node in topological order. */
    size_t index{0};
    /*! \brief Whether this node is referenced by external source, i.e. whether it is an output node */
    bool extern_ref{false};
    /*! \brief The general pattern in the node */
    OpPatternKind pattern{kOpaque};
    /*!
     * \brief The outputs of the node.
     *  With "forward" meaning top-down in the graph, the outputs are the
     *  other Nodes that refer to (consume) this Node.
     */
    LinkedList<Edge> outputs;

    /*!
     * \brief Get all tvm::Node which refer to this node
     * \return The `ref` of every node reachable through `outputs`, in list order.
     */
    std::vector<const tvm::Node*> GetRefs() const {
      std::vector<const tvm::Node*> nodes;
      for (auto* link = outputs.head; link != nullptr; link = link->next) {
        nodes.push_back(link->value.node->ref);
      }
      // Return the local by value: wrapping it in std::move would inhibit NRVO.
      return nodes;
    }
  };
  /*!
   * \brief The node map that maps node to graph
   *  Given a tvm::Node, node_map yields the matching graph Node; the Node's
   *  `outputs` then lead to the tvm::Node(s) that refer to it.
   */
  std::unordered_map<const tvm::Node*, Node*> node_map;
  /*! \brief All the nodes in post DFS order (top-down order of the graph) */
  std::vector<Node*> post_dfs_order;

  /*! \brief Dump the graph into string. */
  void DebugDump() const {
    std::ostringstream os;
    for (size_t i = 0; i < post_dfs_order.size(); ++i) {
      Node* node = post_dfs_order[i];
      os << "node[" << i << "], " << GetRef<NodeRef>(node->ref) << " outputs=[";
      for (auto* link = node->outputs.head; link != nullptr; link = link->next) {
        os << link->value.node->index << ", ";
      }
      os << "]\n";
    }
    LOG(INFO) << os.str();
  }
  /*!
   * \brief create a indexed forward graph.
   * \param arena The arena used for data allocation.
   * \param body The body of the expression to create a graph.
   */
  static SimpleIndexedForwardGraph Create(common::Arena* arena, const Expr& body);

 private:
  class Creator;
};

/*!
 * \brief Builder of the SimpleIndexedForwardGraph: walks the expression and
 *  fills node_map and post_dfs_order.
 */
class SimpleIndexedForwardGraph::Creator : private ExprVisitor {
 public:
  explicit Creator(common::Arena* arena) : arena_(arena) {}

  /*!
   * \brief Build the graph for the given expression.
   * \param body The expression to analyze; it is marked as externally referenced.
   * \return The constructed graph (moved out of this creator).
   */
  SimpleIndexedForwardGraph Prepare(const Expr& body) {
    this->Update(body, nullptr, kOpaque);
    this->VisitExpr(body);
    return std::move(graph_);
  }

 private:
  /*! \brief allocator of all the internal node object */
  common::Arena* arena_;
  // The output.
  SimpleIndexedForwardGraph graph_;
  // attribute equal comparator
  AttrsEqual attr_equal_;

  /*!
   * \brief Update the message stored at the node (i.e. update graph_.node_map).
   *
   *  Finds the graph Node matching `node` in graph_.node_map, creating it if
   *  it does not exist yet. When `parent` is given, an output edge pointing to
   *  `parent` is appended to that Node; otherwise the Node is flagged as
   *  externally referenced. A visitor calls this for each of its inputs, so
   *  the inputs are registered in node_map and the current node is recorded
   *  in their `outputs`.
   *
   * \param node The expression whose graph Node is updated.
   * \param parent The consumer of `node`, or nullptr for an external reference.
   * \param pattern The pattern stored on the new edge (unused when parent is nullptr).
   */
  void Update(const Expr& node, SimpleIndexedForwardGraph::Node* parent, OpPatternKind pattern) {
    // Look up the graph Node for this tvm::Node; allocate one on first sight.
    const tvm::Node* key = node.get();
    SimpleIndexedForwardGraph::Node* current;
    auto it = graph_.node_map.find(key);
    if (it != graph_.node_map.end()) {
      current = it->second;
    } else {
      current = arena_->make<SimpleIndexedForwardGraph::Node>();
      graph_.node_map[key] = current;
    }
    if (parent != nullptr) {
      auto* link = arena_->make<LinkNode<SimpleIndexedForwardGraph::Edge> >();
      link->value.node = parent;
      link->value.pattern = pattern;
      current->outputs.Push(link);
    } else {
      current->extern_ref = true;  // the current Node is an output node
    }
  }

  /*!
   * \brief Append the graph Node of a tvm::Node to graph_.post_dfs_order.
   *
   *  The graph Node must already be present in graph_.node_map and must not
   *  have been added before (its `ref` is still nullptr). Its `ref` and
   *  `index` are filled in here. Called only while visiting `key` itself.
   *
   * \param key The tvm::Node being visited.
   */
  void AddNode(const tvm::Node* key) {
    auto it = graph_.node_map.find(key);
    CHECK(it != graph_.node_map.end()) << "Cannot find node " << GetRef<NodeRef>(key);
    SimpleIndexedForwardGraph::Node* node = it->second;
    CHECK(node->ref == nullptr);
    node->ref = key;
    node->index = graph_.post_dfs_order.size();
    graph_.post_dfs_order.push_back(node);
  }

  // Post order tree
  void VisitExpr_(const FunctionNode* op) final {
    for (auto param : op->params) {
      this->Update(param, nullptr, kOpaque);
    }
    this->Update(op->body, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
  }

  void VisitExpr_(const ConstantNode* op) final {
    // The node referring to this one has already registered it in
    // graph_.node_map through Update().
    this->AddNode(op);
    Node* node = graph_.node_map.at(op);
    DataType dtype = DataType(op->data->dtype);
    // This rule must be consistent with code generator.
    bool is_simple_const =
        (dtype == DataType::Int(32) || dtype == DataType::Int(64) ||
         dtype == DataType::Float(32) || dtype == DataType::Float(64) ||
         dtype == DataType::Bool());
    if (op->is_scalar() && is_simple_const) {
      node->pattern = kElemWise;
    } else {
      // for now, mark non-scalar constant
      // as opaque, we will not choose to fuse it.
      node->pattern = kOpaque;
    }
  }

  void VisitExpr_(const CallNode* call) final {
    // The node referring to this one has already registered it in
    // graph_.node_map through Update().
    CHECK(graph_.node_map.count(call));
    Node* node = graph_.node_map.at(call);
    static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
    // Now we set the pattern of this call.
    //
    // If we see a call mentioning an operator we should mark it with its
    // annotated pattern.
    //
    // If the pattern is not annotated we will default to opaque.
    //
    // Finally if the operator position is not a call node we will
    // need to call Update, as it may be an arbitrary expression.
    //
    // Calls to nn.batch_norm are deliberately kept opaque here.
    OpPatternKind op_pattern = kOpaque;
    const OpNode* opnode = call->op.as<OpNode>();
    if (opnode != nullptr && call->op != Op::Get("nn.batch_norm")) {
      op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
      // if(opnode->name == "nn.max_pool2d") op_pattern = kElemWise;
    } else {
      this->Update(call->op, node, kOpaque);
    }

    node->pattern = op_pattern;
    this->Update(call->op, nullptr, kOpaque);  // an OpNode has no output edge
    const auto* rtype = call->checked_type().as<TensorTypeNode>();
    // pass the analysis back to all the children it references.
    for (size_t i = 0; i < call->args.size(); ++i) {
      const auto* arg_type = call->args[i]->checked_type().as<TensorTypeNode>();
      // specifically check if result type is the same as arguments type
      OpPatternKind edge_pattern = op_pattern;
      if (edge_pattern == kBroadcast && arg_type != nullptr && rtype != nullptr &&
          attr_equal_(rtype->shape, arg_type->shape)) {
        edge_pattern = kElemWise;
      }
      // The current node is an output of each of its input nodes.
      this->Update(call->args[i], node, edge_pattern);
    }
    ExprVisitor::VisitExpr_(call);
    this->AddNode(call);
  }

  void VisitExpr_(const TupleNode* op) final {
    // The node referring to this one has already registered it in
    // graph_.node_map through Update().
    CHECK(graph_.node_map.count(op));
    Node* tuple_node = graph_.node_map.at(op);
    tuple_node->pattern = kTuple;
    for (const Expr& field : op->fields) {
      if (field->checked_type().as<TensorTypeNode>()) {
        this->Update(field, tuple_node, kInjective);
      } else {
        this->Update(field, nullptr, kOpaque);
      }
    }
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const TupleGetItemNode* op) final {
    auto tuple_type = op->tuple->checked_type().as<TupleTypeNode>();
    CHECK(tuple_type);
    // when TVM lowers a fused function, it expects all arguments to be a Tensor or
    // a tuple containing only Tensors. But this tuple may contain a reference or
    // another tuple. To avoid modifying codegen logic, we do not allow fusing through this node
    // if the tuple contains such non Tensor fields. However, all fields will be recursively
    // visited via call to ExprVisitor::VisitExpr_(op) below and corresponding visitor methods.
    bool has_non_tensor = false;
    for (auto ty : tuple_type->fields) {
      if (!ty.as<TensorTypeNode>()) {
        has_non_tensor = true;
        break;
      }
    }
    if (has_non_tensor) {
      this->Update(op->tuple, nullptr, kOpaque);
    } else {
      CHECK(graph_.node_map.count(op));
      Node* node = graph_.node_map.at(op);
      node->pattern = kInjective;
      this->Update(op->tuple, node, kInjective);
    }
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const VarNode* op) final { this->AddNode(op); }

  void VisitExpr_(const LetNode* op) final {
    // do not fuse through let.
    this->Update(op->var, nullptr, kOpaque);
    this->Update(op->value, nullptr, kOpaque);
    this->Update(op->body, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const IfNode* op) final {
    // do not fuse through if.
    this->Update(op->cond, nullptr, kOpaque);
    this->Update(op->true_branch, nullptr, kOpaque);
    this->Update(op->false_branch, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const RefCreateNode* op) final {
    this->Update(op->value, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const RefReadNode* op) final {
    this->Update(op->ref, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const RefWriteNode* op) final {
    this->Update(op->ref, nullptr, kOpaque);
    this->Update(op->value, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const MatchNode* op) final {
    this->Update(op->data, nullptr, kOpaque);
    for (const Clause& c : op->clauses) {
      this->Update(c->rhs, nullptr, kOpaque);
    }
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }
};

// Defined in a header, hence `inline` to keep a single definition across
// translation units.
inline SimpleIndexedForwardGraph SimpleIndexedForwardGraph::Create(common::Arena* arena,
                                                                   const Expr& body) {
  return Creator(arena).Prepare(body);
}

/*!
 * \brief Dominator tree that represent domination or
 *  post domination relation of the node.
 *
 *  The tree is in post order, i.e. bottom-up, which is the same direction as
 *  the original graph and the opposite of the IndexedForwardGraph.
 *
 *        1
 *        |
 *        2
 *       / \
 *      3   4
 *       \ /
 *        5
 *        |
 *        6(root)
 *
 *  change to
 *
 *        1
 *        |
 *        2
 *        |
 *    3   |   4
 *     \  |  /
 *        5
 *        |
 *        6(root)
 */
class SimpleDominatorTree {
 public:
  /*!
   * \brief A node in the dominator tree.
   */
  struct Node {
    /*! \brief The node in the tree (the matching node of the IndexedForwardGraph) */
    SimpleIndexedForwardGraph::Node* gnode{nullptr};
    /*! \brief parent of the tree (the parent of this node within the tree) */
    Node* parent{nullptr};
    /*! \brief current depth (number of levels from the tree root down to this node) */
    int depth{0};
    /*! \brief aggregated pattern to parent */
    OpPatternKind pattern{kOpaque};
  };
  // index -> node. Same order as post_dfs_order of the IndexedForwardGraph:
  // the top-most leaf nodes come first and the bottom-most root node comes last.
  std::vector<Node*> nodes;
  /*!
   * \brief compute a post dominator relation for a given dataflow graph.
   * \param arena The arena used for node allocation.
   * \param graph The graph to be analyze.
   * \return The dominator tree of the graph.
   * \note This algorithm makes use of the fact that graph is DAG,
   *       and runs a single pass algorithm via LCA.
   */
  static SimpleDominatorTree PostDom(common::Arena* arena, const SimpleIndexedForwardGraph& graph);

 private:
  // Combine pattern together: the larger (more restrictive) kind wins.
  static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
    if (lhs > rhs) return lhs;
    return rhs;
  }
  /*!
   * \brief Find the least common ancestor of the two nodes.
   * \param lhs The left node.
   * \param rhs The right node.
   * \param edge_pattern
   *        The combined edge pattern across all the parents.
   * \return The least common ancestor of the two, or nullptr if either walk
   *         runs off the top of the tree.
   */
  static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern) {
    // While the two nodes differ: if their depths differ, move the deeper one
    // up one level; if their depths are equal, move both up one level. When
    // both pointers meet, that node is the nearest common ancestor.
    while (lhs != rhs) {
      if (lhs == nullptr) return nullptr;
      if (rhs == nullptr) return nullptr;
      if (lhs->depth < rhs->depth) {
        edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
        rhs = rhs->parent;
      } else if (rhs->depth < lhs->depth) {
        edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
        lhs = lhs->parent;
      } else {
        edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
        edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
        lhs = lhs->parent;
        rhs = rhs->parent;
      }
    }
    return lhs;
  }
  /*!
   * \brief Find the least common ancestor of a list of nodes.
   * \param input_nodes the nodes.
   * \param edge_pattern
   *        The combined edge pattern across all the parents.
   * \return The least common ancestor of all nodes, or nullptr for an empty list.
   */
  Node* LeastCommonAncestor(const LinkedList<SimpleIndexedForwardGraph::Edge>& input_nodes,
                            OpPatternKind* edge_pattern) {
    auto link = input_nodes.head;
    if (link == nullptr) {
      return nullptr;
    }
    // Map a graph edge to the tree node of its target; that tree node must
    // already have been created.
    auto get_node = [&](const SimpleIndexedForwardGraph::Edge& edge) {
      size_t oindex = edge.node->index;  // visiting index of the Node in the IndexedForwardGraph
      CHECK_LT(oindex, nodes.size());
      Node* onode = nodes[oindex];
      CHECK(onode != nullptr);
      return onode;
    };
    Node* parent = get_node(link->value);
    *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
    link = link->next;
    for (; link != nullptr; link = link->next) {
      parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern);
      *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
    }
    return parent;
  }
  /*!
   * \brief Convert the Node from an SimpleIndexedForwardGraph Node into DominatorTree Node.
   * \param arena The Arena.
   * \param gnode An SimpleIndexedForwardGraph Node.
   * \return The SimpleDominatorTree Node.
   */
  Node* GetNode(common::Arena* arena, SimpleIndexedForwardGraph::Node* gnode) {
    Node* tnode = arena->make<Node>();
    tnode->gnode = gnode;
    if (gnode->extern_ref) {
      // Output node: it becomes a root of the tree.
      // NOTE(review): the original author marked this depth value with "??".
      tnode->depth = 1;
      tnode->parent = nullptr;
      tnode->pattern = kOpaque;
    } else {
      // find the LCAs of all outputs.
      OpPatternKind pattern = kElemWise;
      Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
      tnode->depth = parent ? parent->depth + 1 : 1;
      tnode->parent = parent;
      tnode->pattern = pattern;
    }
    return tnode;
  }
};

// Defined in a header, hence `inline` to keep a single definition across
// translation units.
inline SimpleDominatorTree SimpleDominatorTree::PostDom(common::Arena* arena,
                                                        const SimpleIndexedForwardGraph& graph) {
  SimpleDominatorTree tree;
  tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
  // reverse topo order: bottom-up relative to the graph, because
  // graph.post_dfs_order itself is in forward (top-down) order. This ensures
  // the tree node of every consumer exists before its producers are processed.
  for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
    size_t index = i - 1;
    tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
  }
  return tree;
}

}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_PASS_VACC_FORWARD_GRAPH_H_