浅谈字节跳动开源8比特混合精度Transformer引擎

嵌入式技术

1374人已加入

描述

近年来,Transformer 已经成为了 NLP 和 CV 等领域的主流模型,但庞大的模型参数限制了它的高效训练和推理。于是字节跳动在 2019 年 12 月和 2021 年 6 月分别推出了高效推理和训练引擎 LightSeq,大大加速了 Transformer 系列模型的训练和推理,也打通了 Transformer 从训练到推理的整个流程,极大优化了用户使用体验。最近,LightSeq 训练引擎相关论文[1],被录用难度极高的超算领域国际顶会 SC22 接收,得到了学术界的广泛认可!

字节跳动

SC22 接收论文:https://sc22.supercomputing.org/presentation/?id=pap211&sess=sess154

代码地址:https://github.com/bytedance/lightseq

如何继续提升速度?降低计算精度是比较直接的方法。2017 年以来,fp16 混合精度技术 [2] 获得了广泛应用。在对模型效果无损的前提下,将模型训练和推理的速度提升了 50% 以上。而为了维持模型效果,更低精度的方法(例如 int8)通常需要使用如下传统方案:

首先使用 fp16 混合精度将模型训练至收敛;

然后在模型计算密集型算子的权重、输入和输出位置处,插入伪量化结点,进行量化感知训练;

最后将带有伪量化结点的模型计算图转换到专用的 int8 推理引擎中,进行服务部署和模型推理。

虽然在多数任务上,上述方案可以实现模型效果无损,但还是存在以下问题:

使用方法复杂。例如要多一次量化感知训练 [4] 的过程,并且带有伪量化节点的计算图转换复杂。

训练速度慢。由于目前流行的深度学习框架不支持 int8 精度,所以量化感知训练需要插入 fp16 的伪量化结点来模拟 int8 量化,导致量化感知训练反而比 fp16 混合精度训练慢 2-3 倍。

推理部署难且加速比低。对比 fp32、fp16 等类型,int8 硬件和底层软件库优化相对滞后。例如在 NVIDIA GPU 上,int8 矩阵乘法加速受限于硬件架构和特定 shape,实际加速比远远低于理论值。

在下文中,如无特殊说明,量化都是指的 int8 精度的量化

字节跳动

针对这些问题,字节跳动推出了全新版本的 LightSeq GPU 量化训练与推理引擎。支持 Transformer 系列模型的量化训练与推理,并做到了开箱即用,用户友好。LightSeq 快准狠地实现了 int8 精度的量化训练和推理:

快:A100 多卡训练最高加速 5.2 倍,T4 单卡推理最高加速 8.9 倍。

准:训练和推理效果基本无损。

狠:相同数据量下,显存占用最高减少 68%,模型存储空间减少 75%。

总体来说,LightSeq 新版量化训练与推理引擎具有如下几个优点:

1. 丰富的支持

支持完整的 Transformer 模块和多种解码算法,支持 Transformer、BERT、GPT、BART、ViT 等多种模型结构,支持 Fairseq、Hugging Face、NeurST 等多种训练框架接入量化训练、导出模型以及量化推理,提供了丰富的样例供用户参考。

2. 卓越的性能

相比于 fp16 精度的 LightSeq 推理引擎,int8 量化还可以进一步加速最高 70%,相比于 PyTorch 推理更是达到了最高 8.9 倍的加速比。同时显存占用相比 fp16 推理引擎降低了 30% 左右,模型存储空间只需要原来的四分之一。最后经过多个任务的验证,推理效果几乎无损。

3. 便捷的使用

LightSeq 已经针对多个训练库进行了量化支持,可以一键开启量化训练,然后轻松导出为 LightSeq 支持的模型格式,最后实现量化推理。除此之外,LightSeq 还支持训练后量化,无需额外训练即可体验量化推理。

使用方法

字节跳动

如上图所示,为了最大程度减小量化带来的损失,首先需要用 fp16 精度训练一个浮点数模型,将模型效果训到最好。然后开启量化进行 finetune,得到微调过的量化模型,此时模型效果已经基本恢复到浮点数模型的水平。接着将量化模型转换为 LightSeq 支持的 PB 或者 HDF5 模型格式,最后用 LightSeq 进行量化推理。

安装方法

LightSeq 安装非常简单,只需要一行命令即可:

pip install lightseq

量化训练

LightSeq 支持 Fairseq、Hugging Face、NeurST 等训练框架的量化接入,同时也可以自定义模型并开启量化训练。以 encoder 层为例,只需要先定义浮点数模型,然后开启量化即可:

from lightseq.training import LSTransformerEncoderLayer
from lightseq.training.ops.pytorch.quantization import enable_quant

config = LSTransformerEncoderLayer.get_config(
    model="bert-base",
    max_batch_tokens=4096,
    max_seq_len=512,
    fp16=True,
    local_rank=0,
)
layer = LSTransformerEncoderLayer(config)
# 开启量化
layer.apply(enable_quant)

量化推理

LightSeq 提供了便捷的 python 推理接口,只需要三行代码即可实现快速的量化推理:

import lightseq.inference as lsi

model = lsi.QuantTransformer(pb_path, batch_size)
result = model.infer(input)

此外 LightSeq 还提供了 BERT、GPT、ViT 等模型的 python 接口,分别调用 QuantBert、QuantGpt 和 QuanVit 即可体验。

梯度通信量化

LightSeq 支持 Transformer 模型的梯度通信量化[5],使用 Fairseq 或者 Hugging Face 即可轻松开启分布式量化训练,并同时支持浮点数模型和量化模型。在构建模型后,只需要为模型注册一个 communication hook 即可开启梯度通信量化,再开始训练过程。

from lightseq.training.gradient_comm_quantization import encode_and_decode, GCQState
from torch.nn.parallel import DistributedDataParallel 


# model could be from Fairseq or Hugging Face, wrapped by DDP
model = DistributedDataParallel(model)
state =  GCQState(process_group)
# register hook
model.register_comm_hook(state=state, hook=encode_and_decode)

性能测试

LightSeq 在多个任务上测试了量化训练、量化推理和梯度通信量化的速度,并且分析了显存占用情况和量化模型的效果。

量化训练速度

字节跳动

LightSeq 在 8 张 A100 显卡上进行了训练实验,主要对比对象是 Fairseq 的 Transformer、Hugging Face 的 BERT、GPT2 和 ViT。

可以看出,四种模型结构加速趋势都是类似的,加速比都会随着数据量的增大而减小,原因有三点:

随着数据量的增大,矩阵乘法 GEMM 的占比会明显增加,因此 PyTorch QAT 增加的额外的伪量化结点时间占比会逐渐减小,最后速度会和 PyTorch fp16 无限接近。

与此同时,随着 GEMM 占比升高,LightSeq fp16 自定义算子的提速效果也逐渐减小,因此时间上也会和 PyTorch fp16 无限接近。

由于 Ampere 架构显卡上 int8 GEMM 在 shape 较小时甚至不如 fp16 GEMM 快,在大 shape 下才能稍快一点,因此随着数据量增大,LightSeq int8 也会无限接近 LightSeq fp16 的速度。

量化推理速度

字节跳动

LightSeq 在单张 T4 显卡上进行了推理实验,主要对比对象是 Hugging Face 的 Transformer、BERT、GPT2 和 ViT。

可以看出,随着输入数据量的增大,LightSeq 与 PyTorch 的差距会逐渐减小,这也是 GEMM 占比升高造成的。比较 LightSeq fp16 和 LightSeq int8,可以看出随着数据量的增大,LightSeq int8 越来越快。这是因为在 T4 显卡上,int8 GEMM 的加速会随着 shape 的增大而有明显增加。因此在 T4 显卡上进行量化推理时,输入数据量越大,加速效果越好。

字节跳动

LightSeq 还针对机器翻译多个语向和多个测试集,测试了不同 batch size 下,LightSeq int8 推理相对于 LightSeq fp16 推理的加速比,实验同样是在单张 T4 显卡上进行的,采用的模型都是标准的 Transformer-Big。

可以得到和上文中相同的结论,随着 batch size 的增大,量化推理的加速比会逐渐升高。相比于 LightSeq fp16,最高还可以再加速近 70%,这极大地缩短了线上翻译模型的推理延时。

字节跳动

最后如上图所示,为了展示自动 GEMM 调优技术的效果,LightSeq 测试对比了 A100 显卡上 Transformer 和 BERT 模型 fp16、int8 调优前和 int8 调优后的延时。可以看出调优前某些 shape 的 int8 GEMM 速度甚至比 fp16 还要慢,而调优后全面超越了 fp16。

显存占用

字节跳动

LightSeq 分析了不同 batch size 下,量化模型相对于浮点数模型显存占用的加速比。可以看出随着 batch size 的增大,量化模型的显存占用优势更明显,最高可以减少 30% 左右。而 LightSeq fp16 引擎相对于 PyTorch 模型也极大程度减少了显存占用,因此 LightSeq int8 引擎最终能够减少最多 68% 左右的显存。

量化模型效果

字节跳动

针对机器翻译多个语向和多个测试集,LightSeq 测试了量化模型推理相对于浮点数模型 BLEU 的损失,采用的模型都是标准的 Transformer-Big。

在数据量较大的语向 en2zh 上,LightSeq int8 相对 BLEU 损失较大些,最大达到了 - 0.4。而在数据量较小的语向 en2es 上,LightSeq int8 不仅没有任何效果损失,反而比浮点数模型更好。总体而言,int8 量化模型的平均 BLEU 相比浮点数模型基本无损。在 GLUE 和 SQuAD 等多个任务上,LightSeq 也验证了量化模型的效果。

梯度通信量化

字节跳动

由于在多机多卡场景下通信瓶颈更加明显,所以梯度通信量化主要应用在分布式训练场景。因此 LightSeq 在 2 机 8 卡的 A100 上进行了分布式训练的速度测试。

可以看出,梯度通信量化的训练加速效果整体上随着输入数据的增大而减弱。这主要是因为随着输入数据的增大,计算时间占比升高,梯度通信时间占比减少,梯度量化的收益也随之减小。

LightSeq 还额外增加了不同数量网卡(NIC)下的训练速度测试。可以看到使用梯度通信量化的分布式训练速度相比原始的 LightSeq fp16 有大幅度提升。

量化技术

int8 量化的加速收益主要来自如下几个方面:

GEMM 精度从 fp16 降低到 int8 后,计算时间缩短;

自定义算子采用 int8 输入输出后,数据读写时间缩短;

梯度采用 int8 存储后,多机之间通信时间缩短。

以 Transformer 模型为例,经过 LightSeq fp16 引擎加速后,自定义算子时间大大缩短,而 GEMM 时间占比提升到了 90% 左右,因此优化的重点转移到了 GEMM 提速。将 fp16 GEMM 替换为 int8 GEMM 不仅可以缩短 GEMM 时间,还可以减小前后算子的输入输出位宽,从而减小读写数据的时间。最后多机训练的瓶颈主要在梯度的通信,将梯度量化为 int8 精度可以大大加快分布式训练的速度。

量化原理

字节跳动

为了弥补量化带来的精度损失,通常需要用量化感知训练来模拟量化过程。如上图所示,量化感知训练就是将 float GEMM 的两个 float 输入分别做一遍量化和反量化(称之为伪量化结点),离散化成分段的浮点数输入,然后进行 float GEMM 运算。得到结果后再次进行量化与反量化,得到最终的浮点数结果。而量化的过程是不可导的,因此需要用 STE 方法来估计量化参数的梯度。之所以量化感知训练中需要插入伪量化结点,然后用 float GEMM 去模拟量化过程,是因为 TensorFlow 和 PyTorch 等训练框架不支持 int8 GEMM。

字节跳动

而 LightSeq 量化训练直接采用 int8 GEMM 来真实还原量化过程,因此相比传统的实现要更快,且更加节省显存。在推理的时候,同样采用离散化后的整数进行 int8 GEMM 运算,最后再反量化回浮点数结果。量化推理过程和量化训练完全一致,并且和传统的量化感知训练是完全等价的。

量化位置

字节跳动

整个量化 Transformer 的网络结构如上图所示,红色箭头表示需要加上量化和反量化结点的位置。

首先所有 int8 GEMM 的输入和输出都需要进行量化。由于 int8 GEMM 的 shape 限制,部分 GEMM(例如注意力分数的计算)仍然采用 float GEMM。此外第二层 FFN 的 GEMM 采用的是 int32 的输出,因为它的 GEMM 输入是 ReLU 激活函数的输出结果,只包含正数,非对称,因此如果采用 int8 输出的 GEMM,将无法反量化为正确的浮点数结果。

然后所有的模型权重 weight 都需要存储为 int8 类型,因此需要对 weight 做量化。而权重 bias 参数量较小,无需量化,保留 float 精度反而可以提升模型效果。

最后需要对 decoder 端的 cache 进行量化。因为在推理时,decoder 端的 cache 需要频繁进行读写,因此将 cache 量化为 int8 可以大大加快解码的速度。

量化策略

字节跳动

将一个浮点数矩阵量化为 int8 整数矩阵有很多方法,LightSeq 采用的是对称量化,即将正负数范围对称的浮点数区间等比例地映射到整数区间 [-127, 127] 上。

而实际上浮点数矩阵的数值范围通常并不对称,存在极少的离群值。如果直接按照离群值的范围来量化矩阵,会影响到量化后的精度,所以需要先对矩阵进行数值截断。

LightSeq 采用 PACT 方法进行截断[6],将截断的范围当作模型可学习的参数,然后利用 STE 算法去估计参数的梯度,并进行反向传播优化。根据实践经验,权重 weight 的初始截断范围设为[-1, 1],中间结果的初始截断范围设为[-16, 16],可以在大部分任务上达到最好的效果。最后经过截断范围和其他模型参数的联合优化,量化模型的效果可以达到基本无损。

梯度通信量化

针对分布式训练场景,LightSeq 推出了梯度量化压缩技术。即对浮点精度的梯度进行 int8 量化,以减少梯度通信的时间消耗,从而加速训练,这就是梯度通信量化(GCQ)。

字节跳动

如上图所示,梯度通信量化的主要流程如下:

计算每张卡上各自梯度的截断范围;

对截断范围执行 all-reduce max 操作;

每张卡使用统一的截断范围对各自梯度进行 int8 量化;

对 int8 梯度执行 all-reduce sum 操作;

每张卡对 all-reduce 后的梯度进行反量化,还原为浮点数梯度,并进行参数更新。

为了解决 int8 梯度在 all-reduce 过程中溢出的问题,LightSeq 首先将每张卡上的浮点数梯度除以卡数,再使用除之前的截断范围进行量化,最后进行 all-reduce 操作。这样每张卡上量化后的 int8 整数 all-reduce 完就不会溢出,但是单卡实际用于量化的比特数也因此而减少,所以目前方案在 2 机 8 卡效果几乎无损,但随着卡数的上涨,训练效果会有所下降。以 en2de 和 en2fr 翻译任务为例,在 4 机 8 卡上进行分布式量化训练,BLEU 值分别会下降 0.4 和 1.5 左右。未来 LightSeq 将会持续探索更好的方法来解决这一问题。

通用技术

除了上一章节中提到的量化技术以外,此次更新 LightSeq 还提出了几种通用的优化技术,不仅可以应用在量化模型中,也适用于其它所有精度模型的训练与推理。

算子融合

字节跳动

上图是 encoder 模块量化训练的计算图,LightSeq 将两次 GEMM 运算之间的所有操作融合成一个算子[7],减少了 kernel 调用的次数,因此减少了总的计算时间。

图中黄色矩形表示 int8 GEMM,绿色矩形表示 float GEMM。这里采用 float GEMM 是由于 shape 的限制,不适合使用 int8 GEMM 加速。红色箭头表示流动数据的类型是 int8,绿色箭头表示第二层 FFN 的 GEMM 输出是 int32 数据类型。int8 GEMM 输入输出的量化与反量化操作都被融合到了前后 kernel 里,这不仅可以减少数据搬运,还可以减小显存占用。

字节跳动

在推理时,LightSeq 还针对 decoder 做了优化。如上图所示,在计算 self-attention 时,注意力得分的维度是(batch size, 1, sequence length)。因此在计算 value 乘积时,可以不采用 GEMM 运算,而直接手写加权求和的算子,从而将图中虚线框中的计算融合成一个 kernel。

自动显存管理

字节跳动

模型量化引入了更复杂的张量类型和张量依赖关系,这给显存管理带来新的挑战。为此,LightSeq 设计了新的显存管理机制。如上图所示,主要包括以下过程:

训练启动前,根据每个算子的拓扑依赖关系,自动计算每个张量的生命周期及显存空间大小。其中,包含动态维度的张量按照此维度的最大量进行计算,例如机器翻译任务中的最大句长和最大 batch 句子数量。这些最大量在训练前已被指定;

张量确定生命周期和大小后,分析显存复用关系。其中,无生命周期重合的张量可以共用一片显存空间,所有显存空间都是无数据类型的,可以被分配到任意数据类型的张量上;

根据张量显存复用关系,申请多段显存空间,为每个张量分配实际的显存起止地址。

张量显存复用的分析,LightSeq 借鉴了论文 [3] 中提出的 Greedy by Size for Offset Calculation 方法,做了三个改进:

支持了整个训练过程的显存复用(forward/backward);

不同数据类型能做到显存复用(int8/fp16/fp32);

在多段显存空间上容纳所有张量,而非一段非常大的显存空间,这样能有效提升显存利用率。

自动 GEMM 调优

LightSeq 的 int8 GEMM 采用了 NVIDIA 的 cuBLASLt 库,这也是目前 NVIDIA 显卡上最为高效的矩阵运算库。但是输入数据的 shape 或者显卡不同的话,GEMM 所采用的最优配置(例如数据排布、GEMM 算法等等)也可能不同,因此需要进行自动选取。LightSeq 采取的自动调优方案如下:

在多种型号显卡上(例如 T4 和 A100)进行不同 shape 的 GEMM 最优配置搜索,并将结果保存到配置文件中,用户只需要下载即可;

模型初始化时,加载对应型号显卡的配置文件,解析并保存到键值对为 (shape, 最优配置) 的字典中。如果没有对应型号显卡的配置文件,或者没有需要的 GEMM shape,那么用户可以选择自己搜索并保存,或者直接使用默认配置;

模型前向或后向计算时,根据输入的 shape 在字典中寻找最优配置,然后进行 GEMM 计算。如果没有找到对应的 shape,那么直接采用默认的配置。

未来工作

未来 LightSeq 还将继续探索移动端的低精度量化、反向传播中梯度的量化、大模型量化等方向。

引用

[1] Wang, Xiaohui, et al. "LightSeq2: Accelerated training for transformer-based models on gpus." arXiv preprint arXiv:2110.05722 (2021).

[2] Micikevicius, Paulius, et al. "Mixed precision training." arXiv preprint arXiv:1710.03740 (2017).

[3] Pisarchyk, Yury, and Juhyun Lee. "Efficient memory management for deep neural net inference." arXiv preprint arXiv:2001.03288 (2020).

[4] Jacob, Benoit, et al. "Quantization and training of neural networks for efficient integer-arithmetic-only inference." Proceedings of the IEEE conference on computer vision and pattern recognition. 2018.

[5] Alistarh, Dan, et al. "QSGD: Communication-efficient SGD via gradient quantization and encoding." Advances in neural information processing systems 30 (2017).

[6] Choi, Jungwook, et al. "Pact: Parameterized clipping activation for quantized neural networks." arXiv preprint arXiv:1805.06085 (2018).

[7] Wang, Xiaohui, et al. "LightSeq: A high performance inference library for transformers." arXiv preprint arXiv:2010.13887 (2020).

编辑:黄飞

 

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分