/*!
*
* \file src/relay/pass/vacc/forward_graph.h
*
* \brief This is an indexed data flow graph in forward direction.
*/
#include <tvm/expr_operator.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include "../../../common/arena.h"
#include "../pattern_util.h"
namespace tvm {
namespace relay {
using common::LinkedList;
using common::LinkNode;
/*!
* \brief Indexed data flow graph in forward direction.
* This is a temporary data structure used for operator fusion analysis.
*
* This data structure only captures the dataflow fragment and
* can ignore blocks like let by simply ordering each dataflow block
* and marking the output node as extern_ref.
*/
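/*
Notes:
The tree described by IndexedForwardGraph runs in the opposite direction to the graph in tvm.
Traversal of the graph starts at the FunctionNode. While visiting that FunctionNode, the tvm::Node
objects related to it are added to graph_.node_map through Update, so that when those tvm::Node
objects are visited recursively, the Node (SimpleIndexedForwardGraph::Node) matching the current
node can be found through graph_.node_map.
In general, when visiting a tvm::Node (taking CallNode as an example):
1 Look up the Node (SimpleIndexedForwardGraph::Node) matching the current node in graph_.node_map.
2 Add the tvm::Node objects referenced by the current tvm::Node (including the OpNode of op and
  the Node of each expr in args) to graph_.node_map through Update (the `parent` argument is the
  Node obtained in step 1), so that the recursive visit of those tvm::Node objects can find their
  matching Node (SimpleIndexedForwardGraph::Node) through graph_.node_map.
3 Call ExprVisitor::VisitExpr_(call); to recursively visit the graph above.
4 Once the whole graph ahead of the current tvm::Node has been visited, call this->AddNode(call);
  to append the Node (SimpleIndexedForwardGraph::Node) of the current tvm::Node to
  graph_.post_dfs_order.
Therefore, in graph_.post_dfs_order the leaf nodes of the original graph come first (position 0)
and the root node of the original graph comes last.
*/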
class SimpleIndexedForwardGraph {
 public:
  struct Node;
  /*!
   * \brief A forward edge in the dataflow graph.
   */
  struct Edge {
    /*! \brief The graph node this edge leads to. */
    Node* node{nullptr};
    /*! \brief The pattern recorded for this edge. */
    OpPatternKind pattern{kOpaque};
  };
  /*! \brief A node in the graph. */
  struct Node {
    /*! \brief Weak reference to the tvm::Node this graph node stands for. */
    const tvm::Node* ref{nullptr};
    /*! \brief The index of the node in post_dfs_order. */
    size_t index{0};
    /*! \brief Whether this node is referenced by an external source, i.e. is an output node. */
    bool extern_ref{false};
    /*! \brief The general pattern in the node */
    OpPatternKind pattern{kOpaque};
    /*!
     * \brief The outputs of the node.
     *
     * Forward order means top-down in the graph; the outputs are the other
     * graph nodes that refer to this one.
     */
    LinkedList<Edge> outputs;
    /*!
     * \brief Collect the tvm::Node of every graph node that refers to this node.
     * \return One entry per output edge, in list order. An entry is the `ref` of the
     *         referring graph node and is therefore nullptr while that `ref` is unset.
     */
    std::vector<const tvm::Node*> GetRefs() const {
      std::vector<const tvm::Node*> nodes;
      for (auto* link = outputs.head; link != nullptr; link = link->next) {
        nodes.push_back(link->value.node->ref);
      }
      // Return the local by name: wrapping it in std::move would disable copy elision.
      return nodes;
    }
  };
  /*!
   * \brief The node map that maps a tvm::Node to its graph node.
   *
   * Given a tvm::Node, node_map yields the matching graph Node, whose outputs
   * lead to the one or more tvm::Node objects that refer to it.
   */
  std::unordered_map<const tvm::Node*, Node*> node_map;
  /*! \brief All the nodes in post DFS order (top-down order of the graph). */
  std::vector<Node*> post_dfs_order;
  /*! \brief Dump the graph into a string and write it to LOG(INFO). */
  void DebugDump() const {
    std::ostringstream os;
    for (size_t i = 0; i < post_dfs_order.size(); ++i) {
      Node* node = post_dfs_order[i];
      os << "node[" << i << "], " << GetRef<NodeRef>(node->ref) << " outputs=[";
      for (auto* link = node->outputs.head; link != nullptr; link = link->next) {
        os << link->value.node->index << ", ";
      }
      os << "]\n";
    }
    LOG(INFO) << os.str();
  }
  /*!
   * \brief create a indexed forward graph.
   * \param arena The arena used for data allocation.
   * \param body The body of the expression to create a graph.
   */
  static SimpleIndexedForwardGraph Create(common::Arena* arena, const Expr& body);

 private:
  class Creator;
};
/*!
 * \brief Visitor that fills a SimpleIndexedForwardGraph from a Relay expression.
 *
 * Most visitor methods follow the same order: register the expressions the current node
 * refers to with Update(), recurse through ExprVisitor, then append the current node with
 * AddNode(). The relative order of these steps matters and must be kept.
 */
class SimpleIndexedForwardGraph::Creator : private ExprVisitor {
 public:
  explicit Creator(common::Arena* arena) : arena_(arena) {}
  /*!
   * \brief Build the graph for an expression.
   * \param body The expression to analyze; it is registered as externally referenced.
   * \return The populated graph, moved out of this creator.
   */
  SimpleIndexedForwardGraph Prepare(const Expr& body) {
    this->Update(body, nullptr, kOpaque);
    this->VisitExpr(body);
    // graph_ is a data member, so the explicit move is required to avoid a copy.
    return std::move(graph_);
  }

 private:
  /*! \brief allocator of all the internal node object */
  common::Arena* arena_;
  // The output.
  SimpleIndexedForwardGraph graph_;
  // attribute equal comparator
  AttrsEqual attr_equal_;
  /*!
   * \brief Make sure `node` has a graph node in graph_.node_map and record how it is referenced.
   *
   * In effect this adds one output edge to the graph node of `node`: the visitor of the
   * referring node registers its inputs here and adds itself to their outputs.
   *
   * \param node The expression being referenced.
   * \param parent The graph node that refers to `node`. When nullptr, no edge is added and
   *        the graph node of `node` is marked as externally referenced instead.
   * \param pattern The pattern stored on the new edge; ignored when parent is nullptr.
   */
  void Update(const Expr& node, SimpleIndexedForwardGraph::Node* parent, OpPatternKind pattern) {
    // Find the graph node that belongs to the tvm::Node of `node`; create it if missing.
    const tvm::Node* key = node.get();
    SimpleIndexedForwardGraph::Node* current;
    auto it = graph_.node_map.find(key);
    if (it != graph_.node_map.end()) {
      current = it->second;
    } else {
      current = arena_->make<SimpleIndexedForwardGraph::Node>();
      graph_.node_map[key] = current;
    }
    if (parent != nullptr) {
      auto* link = arena_->make<LinkNode<SimpleIndexedForwardGraph::Edge> >();
      link->value.node = parent;
      link->value.pattern = pattern;
      current->outputs.Push(link);
    } else {
      current->extern_ref = true;  // the current graph node is an output node
    }
  }
  /*!
   * \brief Finalize the graph node of `key` and append it to graph_.post_dfs_order.
   *
   * The graph node must already exist in graph_.node_map and must not have been added
   * before; both conditions are enforced with CHECK. The node's ref and index are set here.
   * Called while visiting the tvm::Node itself.
   *
   * \param key The tvm::Node being added.
   */
  void AddNode(const tvm::Node* key) {
    auto it = graph_.node_map.find(key);
    CHECK(it != graph_.node_map.end()) << "Cannot find node " << GetRef<NodeRef>(key);
    SimpleIndexedForwardGraph::Node* node = it->second;
    CHECK(node->ref == nullptr);
    node->ref = key;
    node->index = graph_.post_dfs_order.size();
    graph_.post_dfs_order.push_back(node);
  }
  // Post order tree
  void VisitExpr_(const FunctionNode* op) final {
    // Parameters and body are registered without a parent, i.e. as externally referenced.
    for (const auto& param : op->params) {
      this->Update(param, nullptr, kOpaque);
    }
    this->Update(op->body, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
  }
  void VisitExpr_(const ConstantNode* op) final {
    // The node referring to this constant has already registered it in graph_.node_map
    // through Update().
    this->AddNode(op);
    Node* node = graph_.node_map.at(op);
    DataType dtype = DataType(op->data->dtype);
    // This rule must be consistent with code generator.
    bool is_simple_const =
        (dtype == DataType::Int(32) || dtype == DataType::Int(64) || dtype == DataType::Float(32) ||
         dtype == DataType::Float(64) || dtype == DataType::Bool());
    if (op->is_scalar() && is_simple_const) {
      node->pattern = kElemWise;
    } else {
      // for now, mark non-scalar constant
      // as opaque, we will not choose to fuse it.
      node->pattern = kOpaque;
    }
  }
  void VisitExpr_(const CallNode* call) final {
    // The node referring to this call has already registered it in graph_.node_map
    // through Update().
    CHECK(graph_.node_map.count(call));
    Node* node = graph_.node_map.at(call);
    static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
    // Looked up once instead of on every visited call: the result does not depend on `call`.
    static const Op batch_norm_op = Op::Get("nn.batch_norm");
    // Now we set the pattern of this call.
    //
    // If we see a call mentioning an operator we should mark it with its
    // annotated pattern.
    //
    // If the pattern is not annotated we will default to opaque.
    //
    // Finally if the operator position is not a call node we will
    // need to call Update, as it may be an arbitrary expression.
    //
    // nn.batch_norm is deliberately excluded from the operator branch, so it keeps the
    // opaque default.
    OpPatternKind op_pattern = kOpaque;
    const OpNode* opnode = call->op.as<OpNode>();
    if (opnode != nullptr && call->op != batch_norm_op) {
      op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
      // if(opnode->name == "nn.max_pool2d") op_pattern = kElemWise;
    } else {
      this->Update(call->op, node, kOpaque);
    }
    node->pattern = op_pattern;
    this->Update(call->op, nullptr, kOpaque);  // registered without an output edge
    const auto* rtype = call->checked_type().as<TensorTypeNode>();
    // pass the analysis back to all the children it references.
    for (size_t i = 0; i < call->args.size(); ++i) {
      const auto* arg_type = call->args[i]->checked_type().as<TensorTypeNode>();
      // specifically check if result type is the same as arguments type
      OpPatternKind edge_pattern = op_pattern;
      if (edge_pattern == kBroadcast && arg_type != nullptr && rtype != nullptr &&
          attr_equal_(rtype->shape, arg_type->shape)) {
        edge_pattern = kElemWise;
      }
      // The current node is an output of each of its input nodes.
      this->Update(call->args[i], node, edge_pattern);
    }
    ExprVisitor::VisitExpr_(call);
    this->AddNode(call);
  }
  void VisitExpr_(const TupleNode* op) final {
    // The node referring to this tuple has already registered it in graph_.node_map
    // through Update().
    CHECK(graph_.node_map.count(op));
    Node* tuple_node = graph_.node_map.at(op);
    tuple_node->pattern = kTuple;
    for (const Expr& field : op->fields) {
      if (field->checked_type().as<TensorTypeNode>()) {
        this->Update(field, tuple_node, kInjective);
      } else {
        this->Update(field, nullptr, kOpaque);
      }
    }
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }
  void VisitExpr_(const TupleGetItemNode* op) final {
    auto tuple_type = op->tuple->checked_type().as<TupleTypeNode>();
    CHECK(tuple_type);
    // when TVM lowers a fused function, it expects all arguments to be a Tensor or
    // a tuple containing only Tensors. But this tuple may contain a reference or
    // another tuple. To avoid modifying codegen logic, we do not allow fusing through this node
    // if the tuple contains such non Tensor fields. However, all fields will be recursively
    // visited via call to ExprVisitor::VisitExpr_(op) below and corresponding visitor methods.
    bool has_non_tensor = false;
    for (const auto& ty : tuple_type->fields) {
      if (!ty.as<TensorTypeNode>()) {
        has_non_tensor = true;
        break;
      }
    }
    if (has_non_tensor) {
      this->Update(op->tuple, nullptr, kOpaque);
    } else {
      CHECK(graph_.node_map.count(op));
      Node* node = graph_.node_map.at(op);
      node->pattern = kInjective;
      this->Update(op->tuple, node, kInjective);
    }
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }
  void VisitExpr_(const VarNode* op) final { this->AddNode(op); }
  void VisitExpr_(const LetNode* op) final {
    // do not fuse through let.
    this->Update(op->var, nullptr, kOpaque);
    this->Update(op->value, nullptr, kOpaque);
    this->Update(op->body, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }
  void VisitExpr_(const IfNode* op) final {
    // do not fuse through if.
    this->Update(op->cond, nullptr, kOpaque);
    this->Update(op->true_branch, nullptr, kOpaque);
    this->Update(op->false_branch, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }
  void VisitExpr_(const RefCreateNode* op) final {
    this->Update(op->value, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }
  void VisitExpr_(const RefReadNode* op) final {
    this->Update(op->ref, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }
  void VisitExpr_(const RefWriteNode* op) final {
    this->Update(op->ref, nullptr, kOpaque);
    this->Update(op->value, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }
  void VisitExpr_(const MatchNode* op) final {
    this->Update(op->data, nullptr, kOpaque);
    for (const Clause& c : op->clauses) {
      this->Update(c->rhs, nullptr, kOpaque);
    }
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }
};
/*!
 * \brief Build the indexed forward graph of an expression.
 * \param arena The arena handed to the Creator for node allocation.
 * \param body The expression to build the graph from.
 * \return The graph produced by Creator::Prepare.
 * \note Marked inline because this file is tagged as a header: a non-inline definition
 *       here would be emitted by every translation unit that includes it.
 */
inline SimpleIndexedForwardGraph SimpleIndexedForwardGraph::Create(common::Arena* arena, const Expr& body) {
  return Creator(arena).Prepare(body);
}
/*!
* \brief Dominator tree that represents the domination or
* post domination relation of the nodes.
* The tree is in post order, i.e. bottom-up, which matches the original graph and is the opposite of IndexedForwardGraph.
* 1
* |
* 2
* / \
* 3 4
* \ /
* 5
* |
* 6(root)
*
* change to
*
* 1
* |
* 2
* |
* 3 | 4
* \ | /
* 5
* |
* 6(root)
*/
class SimpleDominatorTree {
public:
/*!
* \brief A node in the dominator tree.
*/
struct Node {
/*! \brief The node in the tree 对应的IndexedForwardGraph图中的节点*/
SimpleIndexedForwardGraph::Node* gnode{nullptr};
/*! \brief parent of the tree 当前节点在树中的父节点*/
Node* parent{nullptr};
/*! \brief current depth 从根节点(以grpah的根节点对应的Node)到当前节点的层数*/
int depth{0};
/*! \brief aggregated pattern to parent */
OpPatternKind pattern{kOpaque};
};
// index -> node. 顺序跟IndexedForwardGraph中的post_dfs_order一样,
// 最上方的叶节点在nodes开头,最下方的根节点在nodes末尾
std::vector<Node*> nodes;
/*!
* \brief compute a post dominator relation for a given dataflow graph.
* \param arena The arena used for node allocation.
* \param graph The graph to be analyze.
* \return The dominator tree of the graph.
* \note This algorithm makes use of the fact that graph is DAG,
* and runs a single pass algorithm via LCA.
*/
static SimpleDominatorTree PostDom(common::Arena* arena, const SimpleIndexedForwardGraph& graph);
private:
// Combine pattern together.
static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
if (lhs > rhs) return lhs;
return rhs;
}
/*!
* \brief Find the least common ancestor of the two nodes.
* 找到两个节点的最近的共同祖先节点
* \param lhs The left node.
* \param rhs The right node.
* \param edge_pattern
* The combined edge pattern across all the parents.
* \return The least common ancestor of the two.
*/
static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern) {
/*
如果左节点和右节点深度不一样,就将深度大的节点上浮一层
如果2个节点的深度一样了,就都上浮一层
直到2个节点相同,表示该节点就是原来2个节点的最近的祖先节点
*/
while (lhs != rhs) {
if (lhs == nullptr) return nullptr;
if (rhs == nullptr) return nullptr;
if (lhs->depth < rhs->depth) {
edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
rhs = rhs->parent;
} else if (rhs->depth < lhs->depth) {
edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
lhs = lhs->parent;
} else {
edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
lhs = lhs->parent;
rhs = rhs->parent;
}
}
return lhs;
}
/*!
* \brief Find the least common ancestor of a list of nodes.
* \param nodes the nodes.
* \param edge_pattern
* The combined edge pattern across all the parents.
* \return The least common ancestor of all nodes.
*/
Node* LeastCommonAncestor(const LinkedList<SimpleIndexedForwardGraph::Edge>& input_nodes,
OpPatternKind* edge_pattern) {
auto link = input_nodes.head;
if (link == nullptr) {
return nullptr;
}
//
auto get_node = [&](const SimpleIndexedForwardGraph::Edge& edge) {
size_t oindex = edge.node->index; // IndexedForwardGraph中访问该Node的序号
CHECK_LT(oindex, nodes.size());
Node* onode = nodes[oindex];
CHECK(onode != nullptr);
return onode;
};
Node* parent = get_node(link->value);
*edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
link = link->next;
for (; link != nullptr; link = link->next) {
parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern);
*edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
}
return parent;
}
/*!
* \brief Convert the Node from an SimpleIndexedForwardGraph Node into DomaintorTree Node.
* \param arena The Arena.
* \param gnode An SimpleIndexedForwardGraph Node.
* \return The SimpleDominatorTree Node.
*/
Node* GetNode(common::Arena* arena, SimpleIndexedForwardGraph::Node* gnode) {
Node* tnode = arena->make<Node>();
tnode->gnode = gnode;
if (gnode->extern_ref) {
//是输出节点
tnode->depth = 1; //??
tnode->parent = nullptr;
tnode->pattern = kOpaque;
} else {
// find the LCAs of all outputs.
OpPatternKind pattern = kElemWise;
Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
tnode->depth = parent ? parent->depth + 1 : 1;
tnode->parent = parent;
tnode->pattern = pattern;
}
return tnode;
}
};
/*!
 * \brief Build the post-dominator tree of a forward graph.
 * \param arena The arena used to allocate the tree nodes.
 * \param graph The forward graph; one tree node is created per entry of post_dfs_order.
 * \return The tree, with nodes[i] belonging to graph.post_dfs_order[i].
 * \note Marked inline because this file is tagged as a header: a non-inline definition
 *       here would be emitted by every translation unit that includes it.
 */
inline SimpleDominatorTree SimpleDominatorTree::PostDom(common::Arena* arena, const SimpleIndexedForwardGraph& graph) {
  SimpleDominatorTree tree;
  tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
  // Reverse topological order, i.e. bottom-up relative to the graph, because
  // post_dfs_order itself is in forward (top-down) order. GetNode looks up the tree
  // nodes of a node's outputs and CHECKs that they already exist.
  for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
    const size_t index = i - 1;
    tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
  }
  return tree;
}
} // namespace relay
} // namespace tvm