RoPE可能是LLM时代的Resnet

描述

因为和苏神做过一段时间同事,所以2021年就知道RoPE了,当时也没太在意,因为位置编码是在为transformer类模型提供位置信息,在我实际实验中不同位置编码对最终效果差别很小。

2023年LLM大爆发,facebook开源了LLAMA模型,并且采用了RoPE,我也第一时间用上了LLAMA,那会感觉RoPE有点东西,但是还是心理觉得位置编码没那么重要

直到最近fb发了一篇文章《EXTENDING CONTEXT WINDOW OF LARGE LANGUAGE MODELS VIA POSITION INTERPOLATION》通过线性插值+少量微调的方式将LLAMA原始2k的模型轻松拓展到了32k,这时候我感觉到RoPE的强大之处。

进NLP群—>加入NLP交流群

通过线性插值RoPE扩张LLAMA context长度最早其实是在llamacpp项目中被人发现,有人在推理的时候直接通过线性插值将LLAMA由2k拓展到4k,性能没有下降,引起了很多人关注。fb的论文给这个发现提供了理论和实验支撑,进一步发现通过线性插值+微调可以扩展到32k长度。实现非常简单,只需要对位置编码进行线性插值,初始化的时候增加几行代码就行

def RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)
        
        max_position_embeddings = 8192

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(
            self.max_seq_len_cached,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )

        self.scale = 1 / 4
        t *= self.scale

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", emb.cos()[NoneNone, :, :], persistent=False
        )
        self.register_buffer(
            "sin_cached", emb.sin()[NoneNone, :, :], persistent=False
        )

这两天reddit上又出现了ntk RoPE通过引入新的插值的scale,来扩展context,甚至微调都不需要!让人震撼。实现也是极其简单

import transformers

old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None):

    #The method is just these three lines
    max_position_embeddings = 16384
    a = 8 #Alpha value
    base = base * a ** (dim / (dim-2)) #Base change formula

    old_init(self, dim, max_position_embeddings, base, device)


transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init

具体解释可以参考苏神自己写的文章[1]

为什么说RoPE会成为LLM时代的Resnet,首先是两者解决的问题有相似性。

Resnet解决了卷积模型变深之后梯度消失的问题,使的深度模型大放光彩。

RoPE类似的也解决了LLM context过长之后引起的上下文无法关联问题。

两者都有结构简单,方法有效的优点,这个在工程上有极大的优势,个人预感RoPE将会被大规模采用。如同当年Resnet一样。

 

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分