理解KV cache的作用及优化方法

电子说

1.3w人已加入

描述

作者丨紫气东来    

在 Transformer 的 Encoder-base 的模型(如 BERT系列)中,推理和训练过程保持了高度的统一性(差异仅仅在于是否存在反向过程)。而在 Decoder-base 的生成式模型(如 GPT系列)中,推理和训练存在相当大的差异性,主要体现在推理过程具有以下3点特征:

自回归

两阶段(第一阶段输入 prompt,第二阶段输入上一个生成的token)

KV cache

以上三点实际上也是相辅相成、不可分割的,其中自回归的生成模式是根本原因,两阶段是外在的体现形式,KV cache 是优化手段。

下面将通过梳理整个推理过程,来理解 KV cache 的作用及优化方法。

一、KV cache 的由来与基本矛盾

LLM

第一阶段(prompt 输入):

LLM

LLM

LLM

LLM

KV cache 作用过程

第二阶段(token by token):

LLM

LLM

LLM

KV cache的显存占用分析

LLM

LLM

 

batch size s+n KV cache(GB) KV cache/weight
4 4096 81 0.23
16 4096 324 0.93
64 4096 1297 3.71

 

可见随着 batch size 和 长度的增大,KV cache 占用的显存开销快速增大,甚至会超过模型本身。

而 LLM 的窗口长度也在不断增大,因此就出现一组主要矛盾,即:对不断增长的 LLM 的窗口长度的需要与有限的 GPU 显存之间的矛盾。因此优化 KV cache 就显得非常必要。

二、KV cache 优化的典型方法

2.1 共用 KV cache:MQA,GQA

MQA (Multi Query Attention,多查询注意力) 是多头注意力的一种变体。其主要区别在于,在 MQA 中不同的注意力头共享一个K和V的集合,每个头只单独保留了一份查询参数。因此K和V的矩阵仅有一份,这大幅度减少了显存占用,使其更高效。由于MQA改变了注意力机制的结构,因此模型通常需要从训练开始就支持 MQA 。也可以通过对已经训练好的模型进行微调来添加多查询注意力支持,仅需要约 5% 的原始训练数据量 就可以达到不错的效果。包括 Falcon、SantaCoder、StarCoder 等在内很多模型都采用了 MQA 机制。

# Multi Head Attention
self.Wqkv = nn.Linear(     # Multi-Head Attention 的创建方法
    self.d_model,
    3 * self.d_model,     # Q、K和V 3 个矩阵, 所以是 3 * d_model
    device=device
)
query, key, value = qkv.chunk(3, dim=2)      # 每个 tensor 都是 (1, 512, 768)

# Multi Query Attention
self.Wqkv = nn.Linear(       # Multi-Query Attention 的创建方法
    d_model,
    d_model + 2 * self.head_dim,    # 只创建Q的头向量,所以是 1* d_model, 而K和V不再具备单独的头向量, 所以是 2 * self.head_dim
    device=device,
)
query, key, value = qkv.split(
    [self.d_model, self.head_dim, self.head_dim],    # query -> (1, 512, 768), key   -> (1, 512, 96), value -> (1, 512, 96)
    dim=2
)

LLM

MHA v.s. GQA v.s. MQA

GQA(Grouped Query Attention,分组查询注意力)是一种介于多头注意力和 MQA 之间的折中方案。它将查询头(Query Heads)分组,并在每组中共享一个键头(Key Head)和一个值头(Value Head)。表达能力与推理速度:GQA既保留了多头注意力的一定表达能力,又通过减少内存访问压力来加速推理速度。

LLM

MHA, GQA, MQA 性能比较

2.2 窗口优化

LLM

LLM

3)箭型 attention 窗口,在 LM-Infinit 中就已经被提出了,其基本原理和 StreamingLLM 是一致的。

LLM

2.3 量化与稀疏

该类方法是基于压缩的思想,通过量化与稀疏压缩 KV cache 的 显存消耗。

当前主流推理框架都在逐步支持 KV cache 量化,一个典型的案例是 lmdeploy,下图展示了其在 TurboMind 框架下 KV INT8 的支持情况。

LLM

lmdeploy 的推理特性

稀疏的方法也比较简单,其做法无外乎以下几种方式:

LLM

这里最值得一提的是 H2O。简单来说就是通过动态的评价方式来判断需要保留和废弃的KV值,其评估的算法如下所示:

LLM

结果显示,在 KV cache 稀疏到只有原来的 20% 时仍然可以保持很高的精度。

LLM

2.4 存储与计算优化

该方法的典型代表即 vLLM 的 PagedAttention,简单来说就是允许在非连续的内存空间中存储连续的 K 和 V。详情可参考笔者之前的文章,在此不予赘述

FlashDecoding 是在 FlashAttention 的基础上针对 inference 的优化主要分为三步:

长文本下将KV分成更小且方便并行的chunk

对每个chunk的KV,Q和他们进行之前一样的FlashAttention获取这个chunk的结果

对每个chunk的结果进行reduce

LLM

三、StreamingLLM:简洁高效的“无限长度”

StreamingLLM 的基本思想同样是来源于上述的窗口思想,其最大的创新在于提出了识别并保存模型固有的「注意力池」(attention sinks)锚定其推理的初始 token。下面将详细讨论其工作的原理。

3.1 精度是如何保证的?

核心的发现:Lost in the Middle 。

多个研究都发现,self-attention 的注意力比较集中于头部和尾部,对文本中段的注意力相对较弱,如下图所示:

LLM

绘制出 self-attention 的热力图也能看到这一点,由此当文本长度超过额定长度时,头部的 token 就会被遗弃掉,这就会在 softmax 阶段产生很大的问题。

LLM

LLM

LLM

3.2 “无限长度”是如何做到的?

该问实际上可以换种表述为:如何在文本长度不断增加的情况下,保证GPU显存不会溢出。由于该方案主要应用于多轮对话的场景,那么有必要回顾一下当前多轮对话生成的主流做法,概括起来就以下几点:

将用户输入与模型输出拼接,中间做必要分割;

多个轮次之间倒序排列,并拼接;

如果前边所有轮次长度之和超过最大长度,则截断到最大长度;

上述过程可以用代码描述如下:

    history = ["
[|Human|]{}
[|AI|]{}".format(x[0], x[1]) for x in history]
    history.append("
[|Human|]{}
[|AI|]".format(text))
    history_text = ""
    flag = False
    for x in history[::-1]:
        if tokenizer(prompt + history_text + x, return_tensors="pt")["input_ids"].size(-1) <= max_length:
            history_text = x + history_text
            flag = True
        else:
            break
    if flag:
        inputs = tokenizer(prompt + history_text, return_tensors="pt")
        input_ids = inputs["input_ids"][:, -max_length:].to(device)
        torch.cuda.empty_cache()
        return input_ids, text
    else:
        return None

实际上这就是典型的滑动窗口的做法,滑窗 � 的存在保证了 GPU 的显存不会溢出,但是由于上节的讨论,会存在精度损失。

LLM

LLM

审核编辑:黄飞

 

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分