【FlashAttention-V4,非官方】FlashDecoding++

电子说

1.3w人已加入

描述

1. Introdcution

为了提高softmax并行性,之前方法(FlashAttention、FlashDecoding)将计算过程拆分,各自计算partial softmax结果,最后需要通过同步操作来更新partial softmax结果。例如FlashAttention每次计算partial softmax结果都会更新之前的结果,而FlashDecoding是在最后统一更新所有partial softmax结果。

本文在A100 GPU上分析了输入长度为1024的情况,这种同步partial softmax更新操作占Llama2-7B推理的注意力计算的18.8%。(本文没说是FlashAttention还是FlashDecoding的结果,个人认为FlashDecoding的同步更新代价并不大,应该远小于18.8%)

这是LLM推理加速的第一个挑战。此外,本文还提出了两个挑战:

在解码阶段,Flat GEMM操作的计算资源未得到充分利用。这是由于解码阶段是按顺序生成token(一次只生成一个token),GEMM操作趋于flat-shape,甚至batch size等1时变成了GEMV(General Matrix-Vector Multiplication),具体看论文Figure 2。当batch size较小时(e.g., 8),cublas和cutlass会将矩阵填充zeros以执行更大batchsize(e.g., 64)的GEMM,导致计算利用率不足50%。

动态输入和固定硬件配置影响了LLM推理的性能。例如,当batch size较小时,LLM推理的解码过程是memory-bounded,而当batch size较大时是compute-bounded。

针对这3个问题,本文分别提出了对应优化方法:

Asynchronized softmax with unified max value. FlashDecoding++为分块softmax计算设置了一个共享的最大值。这样可以独立计算partial softmax,无需同步更新。

Flat GEMM optimization with double buffering. FlashDecoding++只将矩阵大小填充到8,对比之前针对flat-shaped GEMM设计的为64,提高了计算利用率。论文指出,具有不同shape的flat GEMMs面临的瓶颈也不同,于是进一步利用双缓冲等技术提高kernel性能。

Heuristic dataflow with hardware resource adaption. FlashDecoding++同时考虑了动态输入和硬件配置,针对LLM推理时数据流进行动态kernel优化。

下图展示了以上3种方法的示意图:

数据

2. Backgrounds

LLM推理中的主要操作如下图所示:linear projection(①和⑤)、attention(②、③和④)和feedforward network(⑥)。为简单起见,这里忽略了position embedding、non-linear activation、mask等操作。本文将LLM推理时对Prompt的处理过程称为prefill phase,第二阶段预测过程称为decode phase。这两个阶段的算子基本一致,主要是输入数据的shape是不同的。由于decode phase一次只处理一个令牌(batch size=1,或batch size很小),因此输入矩阵是flat-shape matrices(甚至是vectors),参见下图Decode phase部分中和KV Cache拼接的红色向量。

数据

LLM推理中的另一个问题就是Softmax算子,其需要计算并存储所有全局数据,并且数据量随着数据长度成平方增长,存在内存消耗高和低并行性等问题。一般计算流程如下:

数据

3. Asynchronized Softmax with Unified Maximum Value如下

图b所示,FlashAttention和FlashDecoding对softmax操作进行了分块处理,但是块与块之间需要进行同步(主要是局部最大值)。本文发现这种同步操作的开销约为20%。因此,作者希望去除同步操作,也就是独立计算出partial softmax结果。

数据

数据

数据

数据

数据

4. Flat GEMM Optimization with Double Buffering

Decoding阶段的过程主要由GEMV(batch size=1)或flat GEMM(batch size>1)。GEMV/GEMM运算可以用M、N、K来表示,其中两个相乘矩阵的大小分别为M × K和K × N。一般LLM推理引擎利用Tensor Core使用cuBLAS和CUTLASS等库来加速。尽管Tensor Core适合处理M = 8的GEMM,但这些库为了隐藏memory latency,通常将M维度平铺到64。然而,decode phase的GEMV或flat GEMM的M通远小于64,于是填充0到64,导致计算利用率低下。

数据

数据

数据

数据

数据

数据

为了隐藏memory access latency,本文引入了double buffering技术。具体来说就是在共享内存中分配两个buffer,一个buffer用于执行当前tile的GEMM计算,同时另一个buffer则加载下一个tile GEMM所需的数据。这样计算和内存访问是重叠的,本文在 N 较大时采取这种策略,下图为示意图。

数据

5. Heuristic Dataflow with Hardware Resource Adaption

影响LLM推理性能的因素有很多:(a)动态输入。batch size和输入序列长度的变化造成了工作负载变化。(b)模型多样性。主要指模型结构和模型大小。(c)GPU能力不同。例如内存带宽、缓存大小和计算能力。(d)工程优化。

虽然这些因素构建了一个很大的搜索空间,但LLM中不同layer的同质性大大减少了算子优化的搜索空间。例如,prefill phase和decode phase中有4个GEMV/GEMM操作(K、Q、V投影、O投影、2个FFN),都可以表示为[M, K]和N x K,对应了四种[N, K]组合,如下图所示。此外,prefill phase的M与输入序列长度和batch size有关,decode phase的M只与batch size有关。

数据

本文根据不同的M, K, N选取FastGEMV、flat GEMM(本文方法)、CUTLASS。

数据

个人总结

这篇文章没有FlashAttention和FlashDecoding惊艳,个人觉得FlashDecoding的同步处理代价不大,而且本文中动态调整softmax方法也引入了判断、终止和分支跳转等操作。另一个Double Buffering就是内存优化常用的乒乓buffer,也没什么新东西。

不过话说回来,如今在tranformer架构不变的情况,LLM加速只能靠这些工程手段去优化,的确也有不错效果。还是很有价值的。

 

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分