TVM学习(八)pass总结

电子说

1.3w人已加入

描述

什么是pass?

Pass是TVM中基于relay IR进行的优化,目的是去除冗余算子,进行硬件友好的算子转换,最终能够提高硬件运行效率。由tensorflow等深度学习框架生成的图机构中,含有很多可以优化的算子,比如expand_dim,len等,其实在编译阶段完全可以优化掉,从而能够减少硬件的计算,以及避免出现硬件不支持的算子。

TVM中在include/tvm/ir/transform.h中对pass进行了抽象,主要包括PassContext,PassInfo,Pass,以及Sequential。其中PassContext包含了pass执行依赖的一些参数,比如优化level,analysis report等。PassInfo是一个用于记录pass信息的类,包括pass的opt-level,名称等。和PassContext的区别是PassContext是pass执行所需要获取的条件。Pass就是执行pass的主体,主要就是pass的函数。比如RemoveUnusedFunctions就是执行pass的一个主体函数,目的就是去除冗余算子。Sequential是一个container,装载所有pass。

一些pass

01. RemoveUnusedFunctions

位于src/relay/backend/vm/removed_unused_funcs.cc中,顾名思义就是去除relay IR中的冗余函数。通过从main函数开始遍历,如果一个函数体没有引用其它函数,而同时又没有被其它函数调用,即从relay图上看是一个孤立算子,那么就从IRModule中删除。

 void VisitExpr_(const FunctionNode* func_node) final {
    auto func = GetRef(func_node);
    if (visiting_.find(func) == visiting_.end()) {
      visiting_.insert(func);
      for (auto param : func_node->params) {
        ExprVisitor::VisitExpr(param);
      }
      ExprVisitor::VisitExpr(func_node-> body);
    }
  }

02. ToBasicBlockNormalForm

函数在文件src/relay/trnaforms/to_basic_block_normal_from.cc中。通过遍历IRModule中的每个function,将每个function转换为基本块形式。转换函数是ToBasicBlockNormalFormAux。这个函数包括两个步骤:一是找到基本块(basic block)的边界,TVM中对边界进行了一步抽象,判断每个expr是否属于同一个scope,如果scope相同那么就可以将这些表达式放在一个基本块中;第二步根据每个表达式所属的scope将表达式归属到一个基本块中。

Expr ToBasicBlockNormalFormAux(const Expr& e) {
  // calculate all the dependency between nodes.
  support::Arena arena;
  DependencyGraph dg = DependencyGraph::Create(&arena, e);
  /* The scope of the whole expr is global.
   * The scope of any subexpr, is the lowest common ancestor of all incoming edge.
   * We also record the set of expressions whose scope is lifted.
   */
  std::pair scopes = CalcScope(dg);
  return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);
}

DependencyGraph是一个表达式相互依赖的图结构,通过遍历图中每个节点,找到每个节点的scope。CalcScope在文件src/relay/transforms/to_a_normal_from.cc中。这个函数中重点关注以下代码:

…
        s = LCA(s, expr_scope.at(iit->value));
…
    if (n->new_scope) {
      auto child_scope = std::make_shared(s);
      expr_scope.insert({n, child_scope});
    } else {
      expr_scope.insert({n, s});
}

LCA是获得当前节点的父节点的scope的LCA(least common ancestor),然后将这个scope作为这个节点的scope。了解基本块原理的都知道,寻找基本块首先要找到首指令的位置,然后一个首指令到下一个首指令之间的指令就属于一个基本块。而首指令就是那些具有条件和无条件跳转的指令。在TVM中通过new_scope来标记这些节点,比如Ifnode,FunctionNode,LetNode在建立dependency图的时候,这些节点就被标记为new_scope。这样就建立了dependency节点到scope节点的对应map。同时scope节点也被建立起树结构。

接下来就是建立Fill类,这个类中包含了dependency图以及scope的信息,通过其函数ToBasicBlockNormalForm实现基本块转换。它的基本逻辑通过VisitExpr函数遍历dependency节点,将具有相同scope的节点压入到同一个let_list中。Let_list文档中是这样解释的:

/*!
 * \file let_list.h
 * \brief LetList record let binding and insert let expression implicitly.
 *  using it, one can treat AST as value instead of expression,
 *  and pass them around freely without fear of AST explosion (or effect duplication).
 *  for example, if one write 'b = a + a; c = b + b; d = c + c', the AST will contain 8 'a'.
 *  if one instead write 'b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);',
 *  the AST will contain 2 'a', as b and c are now variables.

Let_list使得抽象语法树简洁化,不会因为变量的复制导致树的爆炸。具有相同的scope的expr被约束到相同的let_list中,用一个var来表达,这样就将表达式转化为var的形式。一个var也就对应了一个基本块。

03. Legalize

Legalize是实现等价函数的转换。主要代码在src/relay/transforms/legalize.cc中。主函数是:

Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
  auto rewriter = Legalizer(legalize_map_attr_name);
  return PostOrderRewrite(expr, &rewriter);
}

在legalize.cc文件中定义了一个继承了ExprRewriter的类,在这个类中实现了对function的替换。我们追踪一下调用的过程。PostOrderRewrite在文件src/relay/ir/expr_functor.cc中。首先建立一个PostOrderRewriter类,然后访问每个节点。在访问节点过程中调用了ExpandDataFlow函数,看一下这个函数的描述:

*
 * ExpandDataflow manually manages a stack and performs DFS to determine the processing
 * order of nodes in an input graph.
 *
 * If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
 * need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
 * and continues iteratively to process the top of the stack. When it finds a node that doesn't
 * match the dataflow types, or a node who's inputs have all been processed, it visits the current
 * leaf via fvisit_leaf.
 *
 * This function should be used internally to other classes to implement mixed-mode traversals. The
 * expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
 * hits a non-dataflow node.
 *
 * fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
 */

主要目的是有区别的去处理graph中的节点,如果fcheck_visited已经确定该节点处理过或者不需要处理,就跳过,通过fvisit_leaf继续访问下一个节点。而在VisitLeaf函数中就调用了legalizer类中的rewrite_函数实现了legalize功能。在Rewrite_中,通过映射表legalize_map_attr_name实现函数的等价转换。

04. SimplifyInference

实现对batch normalization, layer normalization, instance normalization, group normalization, L2 normalization算子的分解,这样做的目的是可以在之后的优化中,将这些算子融合到其它算子上,减少计算量。代码在src/relay/transforms/simplify_inference.cc中。文件中定义了一个InferenceSimplifier类来处理这个问题。看一下这几个normalization的公式:

1 BN:

TVM

2 LN:获得均值和方差是基于同一层不同神经元的数据。归一化公式相同。

3 GN: 将每个输入样本沿着通道进行分组,在每个组内进行归一化。

4 IN:对每个通道的数据进行归一化。

来看一下bacth normalization的处理代码:

Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Expr moving_mean,
                            Expr moving_var, Type tdata) {
  auto ttype = tdata.as();
  CHECK(ttype);
  const auto param = attrs.as< BatchNormAttrs>();
  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast(param->epsilon));
  Expr var_add_eps = Add(moving_var, epsilon);
  Expr sqrt_var = Sqrt(var_add_eps);
  Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var);


  if (param->scale) {
    scale = Multiply(scale, gamma);
  }
  Expr neg_mean = Negative(moving_mean);
  Expr shift = Multiply(neg_mean, scale);
  if (param->center) {
    shift = Add(shift, beta);
  }


  auto ndim = ttype->shape.size();
  int axis = (param->axis <  0) ? param->axis + ndim : param->axis;
  scale = ExpandBiasToMatchAxis(scale, ndim, {axis});
  shift = ExpandBiasToMatchAxis(shift, ndim, {axis});


  Expr out = Multiply(data, scale);
  out = Add(out, shift);
  return out;
}

可以看到就是将batch norm算子分解成最基本的加减乘除算子。

05. EliminateCommonSubexpr

顾名思义,这个pass的目的是消除公共子表达式。公共子表达式类似这种:

a=b+c

d=b+c

两个表达式具有相同的op,同时又有相同的args,而且args的顺序也一样。那么就可以用一个表达式替换。

这个pass的实现在文件src/relay/transforms/eliminate_common_subexpr.cc中。TVM定义了类CommonSubexprEliminator来处理。重载函数Rewrite_实现了对expr的遍历和重写操作。

 Expr Rewrite_(const CallNode* call, const Expr& post) final {
…
    if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef< Op>(op), false)) {
      return new_expr;
    }
    if (fskip_ != nullptr && fskip_(new_expr)) {
      return new_expr;
    }


    auto it = expr_map_.find(new_call->op);
    if (it != expr_map_.end()) {
      for (const Expr& candidate_expr : it->second) {
        if (const CallNode* candidate = candidate_expr.as< CallNode>()) {
          bool is_equivalent = true;
          if (!attrs_equal(new_call->attrs, candidate->attrs)) {
            continue;
          }
          for (size_t i = 0; i <  new_call->args.size(); i++) {
            if (!new_call->args[i].same_as(candidate->args[i]) &&
                !IsEqualScalar(new_call->args[i], candidate->args[i])) {
              is_equivalent = false;
              break;
            }
          }
          if (!is_equivalent) continue;
          return GetRef(candidate);
        }
      }
    }
    expr_map_[new_call->op].push_back(new_expr);
    return new_expr;
  }

使用一个expr_map_映射记录已经遍历过的具有相同op的expr,之后每次遇到相同的op都会对已经记录的expr进行匹配,匹配包括attrs以及args,如果二者都一样的话,证明就是公共子表达式。

没有看过的pass

以上是实现相对简单的pass,TVM中还实现了其它很多pass,就没有一一去读代码了。以后看需要再去读吧。现在做一些罗列:

1 SimplifyExpr

简化一些表达式,具体如何进行简化需要读代码了。

2 CombineParallelConv2D

合并多分支并行的conv2d运算,理解是对多个batch的conv2d进行合并。

3 CombineParalleleDense

将多个batch的dense操作合并为一个batch_matmul操作。

4 CombineParallelBatchMatmul

对多个并行的batch_mamul再进行合并。

这几个combine操作可能是针对GPU器件的一个多数据并行性的优化。

5 FoldConstant

典型的一个常量合并优化。

6 FoldScaleAxis

包含了ForwardFoldScaleAxis和backwardFoldScaleAxis,主要是将scale参数合并到conv/dense操作的权重参数中。

7 CanonicalizeCast

官方解释是: Canonicalize cast expressions to make operator fusion more efficient。理解是对一些cast操作规范化,就是让复杂的cast操作可以更简洁。

8 CanonicalizeOps

规范化一些算子,比如bias_add能够被表示为expand_dims和broadcast_add操作。

审核编辑 黄昊宇

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分