一颗AI芯片需要考量的因素有哪些?

人工智能

636人已加入

描述

量化(Quantization)在加速神经网络方面发挥了巨大作用——从 32 位到 16 位再到 8 位,甚至更快。它是如此重要,以至于谷歌目前因涉嫌侵犯 BF16 的创建者而被起诉,索赔 16 亿至 52 亿美元。所有的目光都集中在数字格式上,因为它们在过去十年中对人工智能硬件效率的提升起到了很大的作用。较低精度的数字格式有助于推倒数十亿参数模型的内存墙。

机器学习

在本文中,我们将从基本原理的基础上,从数字格式的基本原理到神经网络量化的当前技术水平进行技术探讨。我们将介绍浮点与整数、电路设计注意事项、块浮点、MSFP、微缩放格式、对数系统等。我们还将介绍推理的量化和数字格式的差异以及高精度与低精度训练方法。此外,我们将讨论面临量化和准确性损失相关挑战的模型的下一步发展。

上述都是设计一颗AI芯片需要考量的因素。

01. 矩阵乘法

任何现代机器学习模型的大部分都是矩阵乘法。在GPT-3中,每一层都使用大量矩阵乘法:例如,其中一个具体运算是(2048 x 12288)矩阵乘以(12288 x 49152)矩阵,输出(2048 x 49152)矩阵。

重要的是如何计算输出矩阵中的每个单独元素,这可以归结为两个非常大的向量的点积 - 在上面的示例中,大小为 12288。这由 12288 次乘法和 12277 次加法组成,它们累积成一个数字– 输出矩阵的单个元素。

机器学习

通常,这是通过将累加器寄存器初始化为零,然后重复地在硬件中完成的

乘以 x_i * w_i;

将其添加到累加器中;

每个周期的吞吐量均为 1。经过大约 12288 个循环后,输出矩阵的单个元素的累加完成。这种“融合乘加”运算 (FMA:fused multiply-add) 是机器学习的基本计算单元:芯片上有数千个 FMA 单元战略性地排列以有效地重用数据,因此可以并行计算输出矩阵的许多元素,以减少所需的周期数。

上图中的所有数字都需要在芯片内部的某个位以某种方式以位表示:

x_i,输入激活;

w_i,权重;

p_i,成对乘积;

整个输出完成累加之前的所有中间部分累加和;

最终输出总和;

在这个巨大的设计空间中,当今大多数机器学习量化研究都可以归结为两个目标:

足够准确地存储数千亿个权重,同时使用尽可能少的位,从容量和带宽的角度减少内存占用。这取决于用于存储权重的数字格式。

实现良好的能源和面积效率。这主要取决于用于权重和激活的数字格式;

这些目标有时是一致的,有时是不一致的——我们将深入研究这两个目标。

02. 数字格式设计目标 1:芯片效率

许多机器学习芯片计算性能的根本限制是功耗。虽然 H100 理论上可以实现 2,000 TFLOPS 的计算能力,但在此之前它会遇到功率限制 - 因此每焦耳能量的 FLOPs 是一个非常需要跟踪的指标。鉴于现代训练运行现在经常超过 1e25 次flops,我们需要极其高效的芯片,在数月内吸收(sucking)兆瓦功率,才能击败 SOTA。  

机器学习

03. 基本数字格式

首先,让我们深入了解计算中最基本的数字格式:整数。   一、以 2 为底的正整数   正整数具有明显的以 2 为底的表示形式。这些称为 UINT,即无符号整数。以下是 8 位无符号整数(也称为 UINT8,范围从 0 到 255)的一些示例。  

机器学习

  这些整数可以有任意位数,但通常仅支持以下四种格式:UINT8、UINT16、UINT32 和 UINT64。  

二、负整数(Negative integers)

负整数需要一个符号来区分正负。我们可以将一个指示符放在最高有效位中:例如0011 表示+3,1011 表示–3。这称为符号-数值(sign-magnitude)表示。以下是 INT8 的一些示例,其范围从 –128 到 127。请注意,由于第一位是符号,因此最大值实际上已从 255 减半到 127。   符号-数值很直观,但效率很低——您的电路必须实现截然不同的加法和减法算法,而这些算法又不同于没有符号位的无符号整数的电路。有趣的是,硬件设计人员可以通过使用二进制补码表示来解决这个问题,这使得可以对正数、负数和无符号数使用完全相同的进位加法器电路。所有现代 CPU 都使用二进制补码。   在 unsigned int8 中,最大数字 255 是 11111111。如果添加数字 1,255 会溢出到 00000000,即 0。在signed int8 中,最小数字是 -128,最大数字是 127。作为让 INT8 和 UINT8 共享硬件的技巧资源,-1可以用11111111表示。现在当数字加1时,它溢出到00000000,按预期表示0。同样,11111110 可以表示为-2。  

机器学习

  溢出被用作一个功能!实际上,0 到 127 被映射为正常值,128 到 255 被直接映射到 -128 到 -1。

04. 固定点(Fixed Point)

为了更进一步,我们可以在现有硬件上轻松创建新的数字格式,而无需进行修改。虽然这些都是整数,但您可以简单地想象它们是其他东西的倍数!例如,0.025 只是千分之 25,它可以存储为整数 25。现在我们只需要记住其他地方使用的所有数字都是千分之几。   新的“数字格式”可以表示从 –0.128 到 0.127 的千分之一的数字,而没有实际的逻辑变化。完整的数字仍被视为整数,然后小数点固定在右起第三位。这种策略称为定点( fixed point)。   更一般地说,这是一个有用的策略,我们将在本文中多次回顾它 - 如果您想更改可以表示的数字范围,请在某处添加比例因子。(显然,您可以用二进制来执行此操作,但十进制更容易讨论)。

05. 浮点(Floating Point)

但定点有一些缺点,特别是对于乘法。假设您需要计算一万亿乘以一万亿分——尺寸上的巨大差异就是高*动态范围*的一个例子。那么 1012和 10-12都必须用我们的数字格式来表示,所以很容易计算出你需要多少位:从 0 到 1 万亿以万亿分之一的增量计数,你需要 10^24 增量,log2(10^ 24) ~= 80 位来表示具有我们想要的精度级别的动态范围。   每个数字 80 位显然是相当浪费的。您不一定关心绝对精度,您关心相对精度。因此,尽管上述格式能够准确区分 1 万亿和 999,999,999,999.999999999999,但您通常不需要这样做。大多数时候,您关心的是相对于数字大小的误差量。   这正是科学记数法所解决的问题:在前面的示例中,我们可以将一万亿写为 1.00 * 10^12,将一万亿写为 1.00 * 10^-12,这样的存储空间要少得多。这更复杂,但可以让您在相同的上下文中表示极大和极小的数字,而无需担心。   因此,除了符号和值之外,我们现在还有一个指数。IEEE 754-1985 标准化了行业范围内以二进制存储该数据的方式,而当时使用的格式略有不同。主要有趣的格式,32 位浮点数(“float32”或“FP32”)可描述为 (1,8,23):1 个符号位、8 个指数位和 23 个尾数位。  

符号位为0表示正,1表示负;

指数位被解释为无符号整数 e,并表示比例因子 2 e-127,其值可以介于 2-126-和2127之间。更多指数位意味着更大的动态范围;

尾数位表示值 1.。更多尾数位意味着更高的相对精度;

机器学习

  其他位宽已标准化或事实上已采用,例如 FP16 (1,5,10) 和 BF16 (1,8,7)。争论的焦点是范围与精度。  

机器学习

  FP8(1,5,2 或 1,4,3)最近在 OCP 标准中标准化了一些额外的怪癖,但目前还没有定论。许多人工智能硬件公司已经实现了具有稍微优越的变体的芯片,这些变体与标准不兼容。

06. 硅效率(Silicon Efficiency)

回到硬件效率,所使用的数字格式对硅面积和所需功率有巨大影响。  

机器学习

  一、整数硅设计电路(Integer Silicon Design Circuit)

整数加法器是有史以来研究最深入的硅设计问题之一。虽然实际的实现要复杂得多,但考虑加法器的一种方法是将它们想象为根据需要将 1 一直相加并一直加到总和上,因此在某种意义上,n 位加法器正在做一定量的工作到 n   对于乘法,请回想一下小学的长乘法。我们进行 n 位乘以 1 位的乘积,然后最后将所有结果相加。在二进制中,乘以 1 位数字很简单(0 或 1)。这意味着 n 位乘法器本质上由 n 位加法器的 n 次重复组成,因此工作量与 n^2 成正比。   虽然实际实现因面积、功率和频率限制而有很大不同,但通常 1) 乘法器比加法器昂贵得多,但 2) 在低位数(8 位及以下)时,FMA 的功耗和面积成本更高以及来自加法器的更多相对贡献((n 与 n^2 缩放).

  二、浮点电路(Floating Point Circuits)

浮点单位有很大不同。相反,乘积/乘法相对简单。   如果恰好有一个输入符号为负,则符号为负,否则为正。 指数是传入指数的整数和。 尾数是传入尾数的整数积。   相比之下,总和相当复杂。   首先,计算指数差。(假设 exp1 至少与 exp2 一样大 - 如果没有,请在说明中交换它们); 将尾数 2 向下移动 (exp1 - exp2),使其与尾数 1 对齐; 向每个尾数添加隐式前导 1。如果一个符号为负,则对尾数之一执行二进制补码; 将尾数加在一起形成输出尾数; 如果发生溢出,则结果指数加1,尾数下移; 如果结果为负,则将其转换回无符号尾数并将输出符号设置为负; 对尾数进行归一化,使其具有前导 1,然后删除隐式前导 1; 适当舍入尾数(通常舍入到最接近的偶数);   值得注意的是,浮点乘法的成本甚至比整数乘法“更少”,因为尾数乘积中的位数更少,而指数的加法器比乘法器小得多,几乎无关紧要。   显然,这也是极其简化的,特别是我们没有讨论的非正规和 nan 处理占用了大量的空间。但要点是,在低位数浮点中,乘积很便宜,而累加则很昂贵。  

机器学习

  我们提到的所有部分在这里都非常明显 - 将指数相加,尾数的大型乘法器数组,根据需要移动和对齐事物,然后标准化。(从技术上讲,真正的“融合”(“fused”)乘加有点不同,但我们在这里省略了。)  

机器学习

  该图表说明了上述所有要点。有很多东西需要消化,但要点是 INT8 x INT8 的累加和累加到定点 (FX) 的成本是最便宜的,并且由乘法 (“mpy”) 主导,而使用浮点作为操作数或累加格式(通常在很大程度上)由累积成本(“alignadd”+“normacc”)主导。例如,通过使用带有“定点”累加器的 FP8 操作数而不是通常的 FP32,可以节省大量成本。   总而言之,本文和其他论文声称 FP8 FMA 将比 INT8 FMA 多占用 40-50% 的硅面积,并且能源消耗同样更高或更差的说法一直。这是大多数专用 ML 推理芯片使用 INT8 的主要原因。

07. 数字格式设计目标 2:准确性

既然整数总是更便宜,为什么我们不到处使用 INT8 和 INT16 而不是 FP8 和 FP16呢?这取决于这些格式能够如何准确地表示神经网络中实际显示的数字。   我们可以将每种数字格式视为一个查找表。例如,一个非常愚蠢的 2 位数字格式可能如下所示:  

机器学习

  显然,这组四个数字对任何事情都没有多大用处,因为它缺少太多数字 - 事实上,根本没有负数。如果表中不存在神经网络中的数字,那么您所能做的就是将其四舍五入到最近的条目,这会给神经网络带来一点误差。   那么表中理想的值集是多少?表的大小可以有多小?   例如,如果神经网络中的大多数值都接近于零(实际上也是如此),我们希望能够有很多这些数字接近于零,这样我们就可以通过牺牲准确性来获得更高的准确性。哪里没有。   在实践中,神经网络通常是正态分布或拉普拉斯分布,有时根据模型架构的确切数值,存在大量异常值。特别是,对于非常大的语言模型, 往往会出现极端异常值,这些异常值很少见,但对模型的功能很重要。

机器学习

  上图显示了 LLAMA 65B 部分权重。这看起来很像正态分布。如果将此与 FP8 和 INT8 中的数字分布进行比较,很明显浮点集中在重要的地方 - 接近零。这就是我们使用它的原因!  

机器学习

  不过,它仍然与真实分布不太匹配——每次指数递增时,它仍然有点太尖了,但比 int8 好得多。   我们可以做得更好吗?从头开始设计格式的一种方法是最小化平均绝对误差——舍入造成的平均损失量。

08. 对数系统(Log Number Systems)

例如, Nvidia在HotChips 上宣称 Log Number System 是继续扩展过去 8 位数字格式的可能途径。使用对数系统时,舍入误差通常较小,但存在许多问题,包括极其昂贵的加法器。  

机器学习

  NF4 和变体 (AF4) 是 4 位格式,假设权重遵循完全正态分布,则使用精确的查找表来最大限度地减少误差。但这种方法在面积和功耗方面非常昂贵——现在每个操作都需要查找巨大的条目表,这比任何 INT/FP 操作都要糟糕得多。   存在多种替代格式:posits、ELMA、PAL 等。这些技术声称在计算效率或表示准确性方面具有多种优势,但尚未达到商业相关规模。也许其中之一,或者尚未发表/发现的一个,将具有 INT 的成本和 FP 的表征准确性——一些人已经做出了这样的声明,或者更好。   我们个人对 Lemurian Labs PAL 最有希望,但关于其数字格式,还有很多信息尚未披露。他们声称其 16 位精度和范围比 FP16 和 BF16 更好,同时硬件也更便宜。  

机器学习

  随着我们继续扩展到过去的 8 位格式,PAL4 还声称比 HotChips 上的 Nvidia 等对数系统有更好的分布。他们的纸面声明令人惊叹,但目前还没有硬件实现该格式。  

机器学习

09. 块号格式(Block Number Formats)

一个有趣的观察是,元素的大小几乎总是与张量中附近的元素相似。当张量的元素比平常大得多时,附近的元素本质上并不重要——它们相对太小,无法在点积中看到。   我们可以利用这一点 - 我们可以在多个元素之间共享一个指数,而不是对每个数字都使用浮点指数。这节省了很多大部分冗余的指数。   这种方法已经存在了一段时间 - Nervana Flexpoint、Microsoft MSFP12、Nvidia VSQ - 直到 2023 年 OCP 的 Microscaling 才出现。   此时,存在一整套可能的格式,具有不同的权衡。微软试图量化硬件的设计空间:  

机器学习

  硬件供应商面临着一个棘手的问题,即尝试设计高度专业化的高效格式,同时又不关闭可能具有截然不同的数值分布的未来模型架构的大门。

10. 推理(Inference)

上述大部分内容都适用于推理和训练,但每种都有一些特定的复杂性。   推理对成本/功耗特别敏感,因为模型通常只训练一次,但部署到数百万客户。训练也更加复杂,有许多数值上有问题的操作(见下文)。这意味着推理芯片在采用更小、更便宜的数字格式方面通常远远领先于训练芯片,因此模型训练的格式和模型推理的格式之间可能会出现很大的差距。   有许多工具可以从一种格式适应另一种格式,这些工具属于一个范围:   一方面,训练后量化(PTQ:post-training quantization)不需要执行任何实际的训练步骤,只需根据一些简单的算法更新权重:   最简单的方法是将每个权重四舍五入到最接近的值。

The easiest is to simply round each weight to the nearest value.   LLM.int8() 将除一小部分以外的所有异常值权重转换为 INT8; GPTQ 使用有关权重矩阵的二阶信息来更好地量化; Smoothquant 进行数学上等效的变换,尝试平滑激活异常值; AWQ 使用有关激活的信息来更准确地量化最显着的权重; QuIP 对模型权重进行预处理,使其对量化不太敏感; AdaRound 将每一层的舍入分别优化为二次二元优化;   存在许多其他方法并且正在不断发布。许多“训练后”量化方法通过使用某种修改后的训练步骤或代理目标迭代优化量化模型,从而模糊了与训练的界限。这里的关键方面是,这些极大地降低了成本,但现实世界的性能损失通常比人们经常吹捧的简单基准要大。  

另一方面,量化感知训练 (QAT:quantization-aware training) 会改变精度并继续训练一段时间以使模型适应新的精度。所有量化方法都应至少部分使用此机制,以在现实世界性能中实现最小的精度损失。这直接使用常规训练过程来使模型适应量化机制,通常被认为更有效,但计算成本更高。

11. 训练(Training)

由于向后传递,训练稍微复杂一些。有 3 个 matmul——一个在前向传递中,两个在后向传递中。  

机器学习

  每个训练步骤最终都会接收权重,对各种数据进行一系列矩阵乘法,并产生新的权重。   FP8 训练更加复杂。下面是 Nvidia FP8 训练方法的稍微简化版本。  

机器学习

  这个清单的一些显着特点:   每个 matmul 都是 FP8 x FP8 并累加为 FP32 (实际上精度较低,但 Nvidia 告诉大家它是 FP32),然后量化为 FP8 以用于下一层。累加必须比 FP8 具有更高的精度,因为它涉及对同一大型累加器进行数万次连续的小更新,因此每个小更新需要很高的精度才能不向下舍入为零;   每个 FP8 权重张量都带有一个比例因子。由于每一层的范围可能显着不同,因此缩放每个张量以适应该层的范围至关重要;   权重更新(在主框之外)对精度非常敏感,并且通常保持较高的精度(通常为 FP32)。这又归结为幅度不匹配——权重更新与权重相比很小,因此再次需要精度才能使更新不向下舍入为零;   最后,训练与推理的一大区别是梯度有更多的极端异常值,这一点非常重要。可以将激活梯度(例如 SwitchBack、AQT)量化为 INT8,但权重梯度迄今为止抵制了这种努力,必须保留在 FP16 或 FP8 (1,5,2) 中。

审核编辑:黄飞

 

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分