电子说
作者丨紫气东来
在 Transformer 的 Encoder-base 的模型(如 BERT系列)中,推理和训练过程保持了高度的统一性(差异仅仅在于是否存在反向过程)。而在 Decoder-base 的生成式模型(如 GPT系列)中,推理和训练存在相当大的差异性,主要体现在推理过程具有以下3点特征:
自回归
两阶段(第一阶段输入 prompt,第二阶段输入上一个生成的token)
KV cache
以上三点实际上也是相辅相成、不可分割的,其中自回归的生成模式是根本原因,两阶段是外在的体现形式,KV cache 是优化手段。
下面将通过梳理整个推理过程,来理解 KV cache 的作用及优化方法。
一、KV cache 的由来与基本矛盾
第一阶段(prompt 输入):
KV cache 作用过程
第二阶段(token by token):
KV cache的显存占用分析
batch size | s+n | KV cache(GB) | KV cache/weight |
---|---|---|---|
4 | 4096 | 81 | 0.23 |
16 | 4096 | 324 | 0.93 |
64 | 4096 | 1297 | 3.71 |
可见随着 batch size 和 长度的增大,KV cache 占用的显存开销快速增大,甚至会超过模型本身。
而 LLM 的窗口长度也在不断增大,因此就出现一组主要矛盾,即:对不断增长的 LLM 的窗口长度的需要与有限的 GPU 显存之间的矛盾。因此优化 KV cache 就显得非常必要。
二、KV cache 优化的典型方法
2.1 共用 KV cache:MQA,GQA
MQA (Multi Query Attention,多查询注意力) 是多头注意力的一种变体。其主要区别在于,在 MQA 中不同的注意力头共享一个K和V的集合,每个头只单独保留了一份查询参数。因此K和V的矩阵仅有一份,这大幅度减少了显存占用,使其更高效。由于MQA改变了注意力机制的结构,因此模型通常需要从训练开始就支持 MQA 。也可以通过对已经训练好的模型进行微调来添加多查询注意力支持,仅需要约 5% 的原始训练数据量 就可以达到不错的效果。包括 Falcon、SantaCoder、StarCoder 等在内很多模型都采用了 MQA 机制。
# Multi Head Attention self.Wqkv = nn.Linear( # Multi-Head Attention 的创建方法 self.d_model, 3 * self.d_model, # Q、K和V 3 个矩阵, 所以是 3 * d_model device=device ) query, key, value = qkv.chunk(3, dim=2) # 每个 tensor 都是 (1, 512, 768) # Multi Query Attention self.Wqkv = nn.Linear( # Multi-Query Attention 的创建方法 d_model, d_model + 2 * self.head_dim, # 只创建Q的头向量,所以是 1* d_model, 而K和V不再具备单独的头向量, 所以是 2 * self.head_dim device=device, ) query, key, value = qkv.split( [self.d_model, self.head_dim, self.head_dim], # query -> (1, 512, 768), key -> (1, 512, 96), value -> (1, 512, 96) dim=2 )
MHA v.s. GQA v.s. MQA
GQA(Grouped Query Attention,分组查询注意力)是一种介于多头注意力和 MQA 之间的折中方案。它将查询头(Query Heads)分组,并在每组中共享一个键头(Key Head)和一个值头(Value Head)。表达能力与推理速度:GQA既保留了多头注意力的一定表达能力,又通过减少内存访问压力来加速推理速度。
MHA, GQA, MQA 性能比较
2.2 窗口优化
3)箭型 attention 窗口,在 LM-Infinit 中就已经被提出了,其基本原理和 StreamingLLM 是一致的。
2.3 量化与稀疏
该类方法是基于压缩的思想,通过量化与稀疏压缩 KV cache 的 显存消耗。
当前主流推理框架都在逐步支持 KV cache 量化,一个典型的案例是 lmdeploy,下图展示了其在 TurboMind 框架下 KV INT8 的支持情况。
lmdeploy 的推理特性
稀疏的方法也比较简单,其做法无外乎以下几种方式:
这里最值得一提的是 H2O。简单来说就是通过动态的评价方式来判断需要保留和废弃的KV值,其评估的算法如下所示:
结果显示,在 KV cache 稀疏到只有原来的 20% 时仍然可以保持很高的精度。
2.4 存储与计算优化
该方法的典型代表即 vLLM 的 PagedAttention,简单来说就是允许在非连续的内存空间中存储连续的 K 和 V。详情可参考笔者之前的文章,在此不予赘述
FlashDecoding 是在 FlashAttention 的基础上针对 inference 的优化主要分为三步:
长文本下将KV分成更小且方便并行的chunk
对每个chunk的KV,Q和他们进行之前一样的FlashAttention获取这个chunk的结果
对每个chunk的结果进行reduce
三、StreamingLLM:简洁高效的“无限长度”
StreamingLLM 的基本思想同样是来源于上述的窗口思想,其最大的创新在于提出了识别并保存模型固有的「注意力池」(attention sinks)锚定其推理的初始 token。下面将详细讨论其工作的原理。
3.1 精度是如何保证的?
核心的发现:Lost in the Middle 。
多个研究都发现,self-attention 的注意力比较集中于头部和尾部,对文本中段的注意力相对较弱,如下图所示:
绘制出 self-attention 的热力图也能看到这一点,由此当文本长度超过额定长度时,头部的 token 就会被遗弃掉,这就会在 softmax 阶段产生很大的问题。
3.2 “无限长度”是如何做到的?
该问实际上可以换种表述为:如何在文本长度不断增加的情况下,保证GPU显存不会溢出。由于该方案主要应用于多轮对话的场景,那么有必要回顾一下当前多轮对话生成的主流做法,概括起来就以下几点:
将用户输入与模型输出拼接,中间做必要分割;
多个轮次之间倒序排列,并拼接;
如果前边所有轮次长度之和超过最大长度,则截断到最大长度;
上述过程可以用代码描述如下:
history = [" [|Human|]{} [|AI|]{}".format(x[0], x[1]) for x in history] history.append(" [|Human|]{} [|AI|]".format(text)) history_text = "" flag = False for x in history[::-1]: if tokenizer(prompt + history_text + x, return_tensors="pt")["input_ids"].size(-1) <= max_length: history_text = x + history_text flag = True else: break if flag: inputs = tokenizer(prompt + history_text, return_tensors="pt") input_ids = inputs["input_ids"][:, -max_length:].to(device) torch.cuda.empty_cache() return input_ids, text else: return None
实际上这就是典型的滑动窗口的做法,滑窗 � 的存在保证了 GPU 的显存不会溢出,但是由于上节的讨论,会存在精度损失。
审核编辑:黄飞
全部0条评论
快来发表一下你的评论吧 !