1. /*!
    2. *
    3. * \file src/relay/pass/vacc/forward_graph.h
    4. *
 * \brief This is an indexed data flow graph in forward direction.
    6. */
    7. #include <tvm/expr_operator.h>
    8. #include <tvm/relay/analysis.h>
    9. #include <tvm/relay/expr_functor.h>
    10. #include <tvm/relay/op_attr_types.h>
    11. #include <tvm/relay/transform.h>
    12. #include "../../../common/arena.h"
    13. #include "../pattern_util.h"
    14. namespace tvm {
    15. namespace relay {
    16. using common::LinkedList;
    17. using common::LinkNode;
    18. /*!
    19. * \brief Indexed data flow graph in forward direction.
    20. * This is a temporary data structure used for operator fusion analysis.
    21. *
 * This data structure only captures the dataflow fragment and
 * can ignore blocks like let by simply ordering each dataflow block
 * and marking the output node as extern_ref.
    25. */
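/*
Notes:
The tree represented by IndexedForwardGraph runs in the opposite direction to the graph in tvm.
Traversal of the graph starts at the FunctionNode. While visiting that FunctionNode, the tvm::Nodes
related to it are added to graph_.node_map via Update. That way, when those tvm::Nodes are visited
recursively, the Node (SimpleIndexedForwardGraph::Node) matching the current node can be found
through graph_.node_map.
In general, when visiting a tvm::Node (taking CallNode as an example):
1. Look up the Node (SimpleIndexedForwardGraph::Node) matching the current node in graph_.node_map.
2. Add the tvm::Nodes referenced by the current tvm::Node (the OpNode for op and the Node for each
   expr in args) to graph_.node_map via Update (the `parent` argument is the Node found in step 1).
   That way, when those tvm::Nodes are visited recursively, their matching Node can be found
   through graph_.node_map.
3. Call ExprVisitor::VisitExpr_(call) to recursively visit the graph above the current node.
4. Once the whole graph before the current tvm::Node has been visited, call this->AddNode(call) to
   append the Node matching the current tvm::Node to graph_.post_dfs_order.
Therefore, in graph_.post_dfs_order the leaf nodes of the original graph come first (position 0)
and the root node of the original graph comes last.
*/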
    42. class SimpleIndexedForwardGraph {
    43. public:
    44. struct Node;
    45. /*!
    46. * The forward edge in the dataflow graph.
    47. */
    48. struct Edge {
    49. /*! \brief The corresponding node */
    50. Node* node{nullptr};
    51. /*! \brief The respective pattern of this op */
    52. OpPatternKind pattern{kOpaque};
    53. };
    54. /*! \brief A node in the graph. */
    55. struct Node {
    56. /*! \brief weak reference to the corresponding edge. */
    57. const tvm::Node* ref{nullptr};
    58. /*! \brief The index of the node in topological order. */
    59. size_t index{0};
    60. /*! \brief Whether this node is referenced by external source, 即是否是输出节点 */
    61. bool extern_ref{false};
    62. /*! \brief The general pattern in the node */
    63. OpPatternKind pattern{kOpaque};
    64. /*! \brief The outputs of the node. */
    65. // 在graph中从上往下为Forward顺序,outputs就是引用当前Node的其它Node
    66. LinkedList<Edge> outputs;
    67. /*!
    68. * \brief Get all tvm::Node which refer to this node
    69. * \return std::vector<tvm::Node *>
    70. */
    71. std::vector<const tvm::Node*> GetRefs() {
    72. std::vector<const tvm::Node*> nodes;
    73. for (auto* link = outputs.head; link != nullptr; link = link->next) {
    74. nodes.push_back(link->value.node->ref);
    75. }
    76. return std::move(nodes);
    77. }
    78. };
    79. /*! \brief The node map that maps node to graph */
    80. // 给定tvm::Node, 可以通过node_map找到对应的IndexedForwardGraph中的Node
    81. // 然后根据Node的outputs找到引用tvm::Node的一个或多个tvm::Node
    82. std::unordered_map<const tvm::Node*, Node*> node_map;
    83. /*! \brief All the nodes in post DFS order */
    84. // graph中从上往下的顺序
    85. std::vector<Node*> post_dfs_order;
    86. /*! \brief Dump the graph into string. */
    87. void DebugDump() {
    88. std::ostringstream os;
    89. for (size_t i = 0; i < post_dfs_order.size(); ++i) {
    90. Node* node = post_dfs_order[i];
    91. os << "node[" << i << "], " << GetRef<NodeRef>(node->ref) << " outputs=[";
    92. for (auto* link = node->outputs.head; link != nullptr; link = link->next) {
    93. os << link->value.node->index << ", ";
    94. }
    95. os << "]\n";
    96. }
    97. LOG(INFO) << os.str();
    98. }
    99. /*!
    100. * \brief create a indexed forward graph.
    101. * \param arena The arena used for data allocation.
    102. * \param body The body of the expression to create a graph.
    103. */
    104. static SimpleIndexedForwardGraph Create(common::Arena* arena, const Expr& body);
    105. private:
    106. class Creator;
    107. };
    108. // Creator of post dominator tree of the dataflow
    109. class SimpleIndexedForwardGraph::Creator : private ExprVisitor {
    110. public:
    111. explicit Creator(common::Arena* arena) : arena_(arena) {}
    112. SimpleIndexedForwardGraph Prepare(const Expr& body) {
    113. this->Update(body, nullptr, kOpaque);
    114. this->VisitExpr(body);
    115. return std::move(graph_);
    116. }
    117. private:
    118. /*! \brief allocator of all the internal node object */
    119. common::Arena* arena_;
    120. // The output.
    121. SimpleIndexedForwardGraph graph_;
    122. // attribute equal comparator
    123. AttrsEqual attr_equal_;
    124. // Update the message stored at the node.
    125. // 更新graph_.node_map
    126. // 其实就是给node对应的IndexedForwardGraph::Node添加一条输出边IndexedForwardGraph::Edge
    127. // 在当前节点就将输入节点添加到graph_.node_map中,同时将当前节点添加到输入节点的outputs中
    128. void Update(const Expr& node, SimpleIndexedForwardGraph::Node* parent, OpPatternKind pattern) {
    129. // 先根据node对应的tvm::Node在graph_找到对应的IndexedForwardGraph::Node* current
    130. // 如果没有找到就新建一个
    131. const tvm::Node* key = node.get();
    132. SimpleIndexedForwardGraph::Node* current;
    133. auto it = graph_.node_map.find(key);
    134. if (it != graph_.node_map.end()) {
    135. current = it->second;
    136. } else {
    137. current = arena_->make<SimpleIndexedForwardGraph::Node>();
    138. graph_.node_map[key] = current;
    139. }
    140. if (parent != nullptr) {
    141. auto* link = arena_->make<LinkNode<SimpleIndexedForwardGraph::Edge> >();
    142. link->value.node = parent;
    143. link->value.pattern = pattern;
    144. current->outputs.Push(link);
    145. } else {
    146. current->extern_ref = true; //当前Node是输出节点
    147. }
    148. }
    149. // 添加一个tvm::Node
    150. // 必须确保该tvm::Node对应的IndexedForwardGraph::Node存在,且IndexedForwardGraph::Node没有被引用。
    151. // 然后更新该IndexedForwardGraph::Node的ref, index, 并添加到graph_.post_dfs_order中
    152. // 在访问tvm::Node才调用,添加当前的Node
    153. void AddNode(const tvm::Node* key) {
    154. auto it = graph_.node_map.find(key);
    155. CHECK(it != graph_.node_map.end()) << "Cannot find node " << GetRef<NodeRef>(key);
    156. SimpleIndexedForwardGraph::Node* node = it->second;
    157. CHECK(node->ref == nullptr);
    158. node->ref = key;
    159. node->index = graph_.post_dfs_order.size();
    160. graph_.post_dfs_order.push_back(node);
    161. }
    162. // Post order tree
    163. void VisitExpr_(const FunctionNode* op) final {
    164. for (auto param : op->params) {
    165. this->Update(param, nullptr, kOpaque);
    166. }
    167. this->Update(op->body, nullptr, kOpaque);
    168. ExprVisitor::VisitExpr_(op);
    169. }
    170. void VisitExpr_(const ConstantNode* op) final {
    171. //在上一个Node(引用当前Node)中通过Update()已经将当前Node添加到graph_.node_map
    172. this->AddNode(op);
    173. Node* node = graph_.node_map.at(op);
    174. DataType dtype = DataType(op->data->dtype);
    175. // This rule must be consistent with code generator.
    176. bool is_simple_const =
    177. (dtype == DataType::Int(32) || dtype == DataType::Int(64) || dtype == DataType::Float(32) ||
    178. dtype == DataType::Float(64) || dtype == DataType::Bool());
    179. if (op->is_scalar() && is_simple_const) {
    180. node->pattern = kElemWise;
    181. } else {
    182. // for now, mark non-scalar constant
    183. // as opaque, we will not choose to fuse it.
    184. node->pattern = kOpaque;
    185. }
    186. }
    187. void VisitExpr_(const CallNode* call) final {
    188. //在上一个Node(引用当前Node)中通过Update()已经将当前Node添加到graph_.node_map
    189. CHECK(graph_.node_map.count(call));
    190. Node* node = graph_.node_map.at(call);
    191. static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
    192. // Now we set the pattern of this call.
    193. //
    194. // If we see a call mentioning an operator we should mark it with its
    195. // annotated pattern.
    196. //
    197. // If the pattern is not annotated we will default to opaque.
    198. //
    199. // Finally if the operator position is not a call node we will
    200. // need to call Update, as it may be an arbitrary expression.
    201. OpPatternKind op_pattern = kOpaque;
    202. const OpNode* opnode = call->op.as<OpNode>();
    203. if (opnode != nullptr && call->op != Op::Get("nn.batch_norm")) {
    204. op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
    205. // if(opnode->name == "nn.max_pool2d") op_pattern = kElemWise;
    206. } else {
    207. this->Update(call->op, node, kOpaque);
    208. }
    209. node->pattern = op_pattern;
    210. this->Update(call->op, nullptr, kOpaque); // OpNode没有输出边
    211. const auto* rtype = call->checked_type().as<TensorTypeNode>();
    212. // pass the analysis back to all the children it references.
    213. for (size_t i = 0; i < call->args.size(); ++i) {
    214. const auto* arg_type = call->args[i]->checked_type().as<TensorTypeNode>();
    215. // specifically check if result type is the same as arguments type
    216. OpPatternKind edge_pattern = op_pattern;
    217. if (edge_pattern == kBroadcast && arg_type != nullptr && rtype != nullptr &&
    218. attr_equal_(rtype->shape, arg_type->shape)) {
    219. edge_pattern = kElemWise;
    220. }
    221. // 当前节点的输入节点的输出节点就是当前节点
    222. this->Update(call->args[i], node, edge_pattern);
    223. }
    224. ExprVisitor::VisitExpr_(call);
    225. this->AddNode(call);
    226. }
    227. void VisitExpr_(const TupleNode* op) final {
    228. //在上一个Node(引用当前Node)中通过Update()已经将当前Node添加到graph_.node_map
    229. CHECK(graph_.node_map.count(op));
    230. Node* tuple_node = graph_.node_map.at(op);
    231. tuple_node->pattern = kTuple;
    232. for (const Expr& field : op->fields) {
    233. if (field->checked_type().as<TensorTypeNode>()) {
    234. this->Update(field, tuple_node, kInjective);
    235. } else {
    236. this->Update(field, nullptr, kOpaque);
    237. }
    238. }
    239. ExprVisitor::VisitExpr_(op);
    240. this->AddNode(op);
    241. }
    242. void VisitExpr_(const TupleGetItemNode* op) final {
    243. auto tuple_type = op->tuple->checked_type().as<TupleTypeNode>();
    244. CHECK(tuple_type);
    245. // when TVM lowers a fused function, it expects all arguments to be a Tensor or
    246. // a tuple containing only Tensors. But this tuple may contain a reference or
    247. // another tuple. To avoid modifying codegen logic, we do not allow fusing through this node
    248. // if the tuple contains such non Tensor fields. However, all fields will be recursively
    249. // visited via call to ExprVisitor::VisitExpr_(op) below and corresponding visitor methods.
    250. bool has_non_tensor = false;
    251. for (auto ty : tuple_type->fields) {
    252. if (!ty.as<TensorTypeNode>()) {
    253. has_non_tensor = true;
    254. break;
    255. }
    256. }
    257. if (has_non_tensor) {
    258. this->Update(op->tuple, nullptr, kOpaque);
    259. } else {
    260. CHECK(graph_.node_map.count(op));
    261. Node* node = graph_.node_map.at(op);
    262. node->pattern = kInjective;
    263. this->Update(op->tuple, node, kInjective);
    264. }
    265. ExprVisitor::VisitExpr_(op);
    266. this->AddNode(op);
    267. }
    268. void VisitExpr_(const VarNode* op) final { this->AddNode(op); }
    269. void VisitExpr_(const LetNode* op) final {
    270. // do not fuse through let.
    271. this->Update(op->var, nullptr, kOpaque);
    272. this->Update(op->value, nullptr, kOpaque);
    273. this->Update(op->body, nullptr, kOpaque);
    274. ExprVisitor::VisitExpr_(op);
    275. this->AddNode(op);
    276. }
    277. void VisitExpr_(const IfNode* op) final {
    278. // do not fuse through if.
    279. this->Update(op->cond, nullptr, kOpaque);
    280. this->Update(op->true_branch, nullptr, kOpaque);
    281. this->Update(op->false_branch, nullptr, kOpaque);
    282. ExprVisitor::VisitExpr_(op);
    283. this->AddNode(op);
    284. }
    285. void VisitExpr_(const RefCreateNode* op) final {
    286. this->Update(op->value, nullptr, kOpaque);
    287. ExprVisitor::VisitExpr_(op);
    288. this->AddNode(op);
    289. }
    290. void VisitExpr_(const RefReadNode* op) final {
    291. this->Update(op->ref, nullptr, kOpaque);
    292. ExprVisitor::VisitExpr_(op);
    293. this->AddNode(op);
    294. }
    295. void VisitExpr_(const RefWriteNode* op) final {
    296. this->Update(op->ref, nullptr, kOpaque);
    297. this->Update(op->value, nullptr, kOpaque);
    298. ExprVisitor::VisitExpr_(op);
    299. this->AddNode(op);
    300. }
    301. void VisitExpr_(const MatchNode* op) final {
    302. this->Update(op->data, nullptr, kOpaque);
    303. for (const Clause& c : op->clauses) {
    304. this->Update(c->rhs, nullptr, kOpaque);
    305. }
    306. ExprVisitor::VisitExpr_(op);
    307. this->AddNode(op);
    308. }
    309. };
    310. SimpleIndexedForwardGraph SimpleIndexedForwardGraph::Create(common::Arena* arena, const Expr& body) {
    311. return Creator(arena).Prepare(body);
    312. }
/*!
 * \brief Dominator tree that represent domination or
 * post domination relation of the node.
 * The tree is in post order, i.e. bottom-up: the same direction as the
 * original graph and the opposite of IndexedForwardGraph.
 *
 *      1
 *      |
 *      2
 *     / \
 *    3   4
 *     \ /
 *      5
 *      |
 *      6(root)
 *
 * change to
 *
 *      1
 *      |
 *      2
 *      |
 *   3  |  4
 *    \ | /
 *      5
 *      |
 *      6(root)
 */
    339. class SimpleDominatorTree {
    340. public:
    341. /*!
    342. * \brief A node in the dominator tree.
    343. */
    344. struct Node {
    345. /*! \brief The node in the tree 对应的IndexedForwardGraph图中的节点*/
    346. SimpleIndexedForwardGraph::Node* gnode{nullptr};
    347. /*! \brief parent of the tree 当前节点在树中的父节点*/
    348. Node* parent{nullptr};
    349. /*! \brief current depth 从根节点(以grpah的根节点对应的Node)到当前节点的层数*/
    350. int depth{0};
    351. /*! \brief aggregated pattern to parent */
    352. OpPatternKind pattern{kOpaque};
    353. };
    354. // index -> node. 顺序跟IndexedForwardGraph中的post_dfs_order一样,
    355. // 最上方的叶节点在nodes开头,最下方的根节点在nodes末尾
    356. std::vector<Node*> nodes;
    357. /*!
    358. * \brief compute a post dominator relation for a given dataflow graph.
    359. * \param arena The arena used for node allocation.
    360. * \param graph The graph to be analyze.
    361. * \return The dominator tree of the graph.
    362. * \note This algorithm makes use of the fact that graph is DAG,
    363. * and runs a single pass algorithm via LCA.
    364. */
    365. static SimpleDominatorTree PostDom(common::Arena* arena, const SimpleIndexedForwardGraph& graph);
    366. private:
    367. // Combine pattern together.
    368. static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
    369. if (lhs > rhs) return lhs;
    370. return rhs;
    371. }
    372. /*!
    373. * \brief Find the least common ancestor of the two nodes.
    374. * 找到两个节点的最近的共同祖先节点
    375. * \param lhs The left node.
    376. * \param rhs The right node.
    377. * \param edge_pattern
    378. * The combined edge pattern across all the parents.
    379. * \return The least common ancestor of the two.
    380. */
    381. static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern) {
    382. /*
    383. 如果左节点和右节点深度不一样,就将深度大的节点上浮一层
    384. 如果2个节点的深度一样了,就都上浮一层
    385. 直到2个节点相同,表示该节点就是原来2个节点的最近的祖先节点
    386. */
    387. while (lhs != rhs) {
    388. if (lhs == nullptr) return nullptr;
    389. if (rhs == nullptr) return nullptr;
    390. if (lhs->depth < rhs->depth) {
    391. edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
    392. rhs = rhs->parent;
    393. } else if (rhs->depth < lhs->depth) {
    394. edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
    395. lhs = lhs->parent;
    396. } else {
    397. edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
    398. edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
    399. lhs = lhs->parent;
    400. rhs = rhs->parent;
    401. }
    402. }
    403. return lhs;
    404. }
    405. /*!
    406. * \brief Find the least common ancestor of a list of nodes.
    407. * \param nodes the nodes.
    408. * \param edge_pattern
    409. * The combined edge pattern across all the parents.
    410. * \return The least common ancestor of all nodes.
    411. */
    412. Node* LeastCommonAncestor(const LinkedList<SimpleIndexedForwardGraph::Edge>& input_nodes,
    413. OpPatternKind* edge_pattern) {
    414. auto link = input_nodes.head;
    415. if (link == nullptr) {
    416. return nullptr;
    417. }
    418. //
    419. auto get_node = [&](const SimpleIndexedForwardGraph::Edge& edge) {
    420. size_t oindex = edge.node->index; // IndexedForwardGraph中访问该Node的序号
    421. CHECK_LT(oindex, nodes.size());
    422. Node* onode = nodes[oindex];
    423. CHECK(onode != nullptr);
    424. return onode;
    425. };
    426. Node* parent = get_node(link->value);
    427. *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
    428. link = link->next;
    429. for (; link != nullptr; link = link->next) {
    430. parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern);
    431. *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
    432. }
    433. return parent;
    434. }
    435. /*!
    436. * \brief Convert the Node from an SimpleIndexedForwardGraph Node into DomaintorTree Node.
    437. * \param arena The Arena.
    438. * \param gnode An SimpleIndexedForwardGraph Node.
    439. * \return The SimpleDominatorTree Node.
    440. */
    441. Node* GetNode(common::Arena* arena, SimpleIndexedForwardGraph::Node* gnode) {
    442. Node* tnode = arena->make<Node>();
    443. tnode->gnode = gnode;
    444. if (gnode->extern_ref) {
    445. //是输出节点
    446. tnode->depth = 1; //??
    447. tnode->parent = nullptr;
    448. tnode->pattern = kOpaque;
    449. } else {
    450. // find the LCAs of all outputs.
    451. OpPatternKind pattern = kElemWise;
    452. Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
    453. tnode->depth = parent ? parent->depth + 1 : 1;
    454. tnode->parent = parent;
    455. tnode->pattern = pattern;
    456. }
    457. return tnode;
    458. }
    459. };
    460. SimpleDominatorTree SimpleDominatorTree::PostDom(common::Arena* arena, const SimpleIndexedForwardGraph& graph) {
    461. SimpleDominatorTree tree;
    462. tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
    463. // reverse topo order 从下往上的顺序(相对graph而言)
    464. // 因为graph.post_dfs_order的顺序是Forward,即从上往下的顺序
    465. for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
    466. size_t index = i - 1;
    467. tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
    468. }
    469. return tree;
    470. }
    471. } // namespace relay
    472. } // namespace tvm