DCA,不需训练让Llama上下文扩大48倍的方法

打印 上一主题 下一主题

主题 820|帖子 820|积分 2475

论文标题:Training-Free Long-Context Scaling of Large Language Models
论文所在:https://arxiv.org/pdf/2402.17463
最近研究Qwen2和Qwen2.5论文的时候,发现都有用到DCA(Dual Chunk Attention),transformers库里面的Qwen2Model只有RoPE和FlashAttention2的实现,没找到DCA代码。更多DCA的评估和使用教程可以去看Qwen团队的博客:https://qwenlm.github.io/zh/blog/qwen2.5-1m/
DCA讲的是在不进行训练、微调的情况下,大大增加模型的输入长度且天生质量不会低落太多。DCA的可行性显然已经被qwen-long这类长文本模型证明了。
DCA主要办理的是RoPE内相对位置编码的差值矩阵问题,实在很简单。但是看到文章把一个矩阵拆解成三个维度介绍的样子,有点好笑,想起了我研究生时期用TFIDF做文本分类的时候,天天写“类间权重”、“类内权重”这些名词的狼狈~
Abstract

大型模型在输入token数量超过其预训练长度时,天生本事会断崖式降落,但如果直接训练长序列的大模型成本非常高。因此作者提出了双块注意力(DCA),使 LLAMA2 70B 能够支持超过 10 万个token的上下文输入,而且不需要连续训练。
DCA可以与 Flash Attention 无缝集成。除此之外,DCA 在现实的长上下文任务上的性能与微调模型相当甚至更好。与专有模型相比,DCA的免训练 70B 模型达到了 gpt-3.5-16k 性能的 94%。
   【注】大模型的外推本事上,还有苏剑林提出的RoPE,旋转位置编码,巧妙地采用复数替代正弦余弦,被如今大模型广泛采用,苏神牛!
  Introduction

让LLM拥有长上下文处理本事很重要,具体体如今:


  • 长文档剖析场景
  • 保留更多更长的历史对话谈天
通过微调来进步上下文长度的近期研究结果主要有:
1)短上下文模型训练:通过在长文本序列上进一步训练本来为短上下文设计的模型,可以提升这些模型处理长上下文的本事。
2)Llama2 Long:模型通过在长文本数据和原始Llama2预训练语料库的混合上进行训练,也可以提升上下文长度。
但是这些方法最大的缺点就是需要微调,微调的难点在于模型数据集很难拿到(通常不公开),而且微调训练有成本。
因此,更多的方法是探究不微调不训练而使得模型有更长输入的算法。


  • LMinfiniteStreamingLLM等免训练方法通过选择性地保留基本的局部信息来处理扩展序列,从而在处理长文本时保持较低的Perplexity。但是失去了长距离文本的上下文信息。
   【注】也就是雷同于CNN和RNN的区别。这些方法就像是CNN,用卷积提取信息;而decoder-only与attention更多的需要全局序列依赖。
  

  • 还有一些基于LlaMa与RoPE开发出来的技能,如位置插值(Position Interpolation)、NTK感知旋转编码(NTK-Aware RoPE)。
固然这些方法可以低落训练长上下文输入模型的成本,但是作者发现,一旦输入长度超过训练长度两倍,这些方法就会导致Perplexity显著增加。
   【注】Perplexity(PPL)在NLP范畴叫做困惑度,用于衡量语言模型性能,公式为:
                                               P                               P                               L                               =                                           2                                               −                                                   1                                        N                                                                ∑                                                       i                                           =                                           1                                                      N                                                                               log                                           ⁡                                                      2                                                  P                                     (                                                   w                                        i                                                  ∣                                                   w                                        1                                                  ,                                                   w                                        2                                                  ,                                     .                                     .                                     .                                     ,                                                   w                                                       i                                           −                                           1                                                                )                                                             PPL = 2^{-\frac{1}{N} \sum_{i=1}^{N} \log_2 P(w_i | w_1, w_2, ..., w_{i-1})}                        PPL=2−N1​∑i=1N​log2​P(wi​∣w1​,w2​,...,wi−1​)
  这是个很古早的性能指标了。我记得从前学马尔可夫链的时候就看到过这个指标。
  差别于以上历史方法,DCA作出如下改进:
1)重用原始位置索引以及其嵌入向量,但是重新设计了相对位置矩阵的构造。
2)在块注意力上,设置每个块都小于训练窗口长度,并且提出intra-chunk attention、inter-chunk attention、succes-sive chunk attention,分别负责chunk内注意力、chunk间注意力、连续chunk注意力的提取。
Background

位置编码

论文在这里介绍了Transformer模型使用的正余弦位置编码和旋转位置编码。
这里就不要看论文的了。如今大模型的位置编码分为绝对位置编码相对位置编码
绝对位置编码也就是输入该token的索引(i)来获得嵌入,这种写死的方法不可以外推,实在如今已经没人在用了。
相对位置编码是在计算attention时,输入为两个token的索引(i和j),因为注意力实在就是两两token之间计算,因此可行。并且无论输入长度怎么拓展,相对位置编码都能支持。RoPE就是使用复数推导,同时具有绝对位置编码和相对位置编码的技能,被业界广泛采用。
RoPE外推不敷之处

RoPE固然是很好的位置编码技能,但是仍然有不敷之处:假设训练上下文长度为6,那么模型就从未训练过超过 6 的相对位置,当输入为10时,超过6的相对位置的结果就不会太好。

因此,PI和NTK通过减少整个位置矩阵的大小来缓解这类问题,以确保它在训练期间处于观察到的上下文长度范围内。
   【注】也就是说,RoPE在理论上办理了位置编码无穷外推的可行性,但是真要外推的时候,还是会受限制于训练窗口的大小。
  Method

   【注】DCA就是在改索引矩阵的下标,其目标就是把图1的矩阵值,改为图2©中的矩阵值。
  

如图2所示,三个图分别代表chunk内注意力、chunk间注意力、连续chunk注意力对应的相对位置编码的索引矩阵的值。
设模型预训练的长度为                                   c                              c                  c,本次请求模型给模型的输入长度为                                   l                              l                  l,chunk的长度为                                   s                              s                  s,那么我们可以把本次输入拆分为                                             l                            s                                       \frac{l}{s}                  sl​块。
在论文中,                                   c                         =                         10                              c=10                  c=10,                                   l                         =                         12                              l=12                  l=12,                                   s                         =                         6                              s=6                  s=6,并且此时键key的下标                                             P                            k                                  =                         {                         0                         ,                         1                         ,                         2                         ,                         .                         .                         .                         ,                         l                         −                         1                         }                              P_k=\{0,1,2,...,l-1\}                  Pk​={0,1,2,...,l−1},查询query的下标                                             P                            q                                  =                         {                         0                         ,                         1                         ,                         2                         ,                         .                         .                         .                         ,                         l                         −                         1                         }                              P_q=\{0,1,2,...,l-1\}                  Pq​={0,1,2,...,l−1}。
Intra-Chunk Attention


Intra-Chunk Attention 用于计算同一 chunk 中 queries 和 key 的内积。对于长度为                                    l                              l                  l 的长序列,将序列分别为为                                             l                            s                                       \frac{l}{s}                  sl​块长度为                                   s                              s                  s的chunk。如今界说新的键key的下标                                             P                            k                                       P_k                  Pk​与查询的下标                                             P                            q                                       I                               n                               t                               r                               a                                                 P_q^{Intra}                  PqIntra​:
                                                    P                               k                                      =                                       P                               q                                           I                                  n                                  t                                  r                                  a                                                 =                            {                            0                            ,                            1                            ,                            2                            ,                            .                            .                            .                            ,                            l                            −                            1                            }                                                         m                            o                            d                                                         s                            =                            {                            0                            ,                            1                            ,                            2                            ,                            3                            ,                            4                            ,                            5                            ,                            0                            ,                            1                            ,                            2                            ,                            3                            ,                            4                            ,                            5                            }                                  P_k = P_q^{Intra} = \{0,1,2,...,l-1\} \ mod \ s = \{0,1,2,3,4,5,0,1,2,3,4,5\}                     Pk​=PqIntra​={0,1,2,...,l−1} mod s={0,1,2,3,4,5,0,1,2,3,4,5}
然后计算相对距离                                   M                              M                  M:
                                         M                            [                            i                            ]                            [                            j                            ]                            =                                       P                               q                                           I                                  n                                  t                                  r                                  a                                                 [                            i                            ]                            −                                       P                               k                                      [                            j                            ]                                  M[j]=P_{q}^{Intra }-P_{k}[j]                     M[j]=PqIntra​−Pk​[j]
以是就得到了图3中的矩阵值。
Inter-Chunk Attention


为了聚合来自其他 chunk 的信息,论文引入了 Inter-Chunk Attention。同时由于因果注意力矩阵的特点(参考图1),Inter-Chunk Attention天生出来的矩阵值在块内要满足如下特点:

  • 最大值不能超过预训练输入长度                                        c                                  c                     c,超过了不就变成原始RoPE了嘛;
  • 从x轴方向延申,其值要递减,表现距离越来越近;
  • 从x轴方向延申,其值固然要递减,但是要保证尽可能比Intra-Chunk Attention 天生的值要大,否则没法体现块间长距离。
因此,论文提出让索引从能接受的最大距离                                   c                         −                         1                              c-1                  c−1开始降落,如今界说查询的下标                                             P                            q                                       I                               n                               t                               e                               r                                                 P_q^{Inter}                  PqInter​:
                                                    P                               q                                           I                                  n                                  t                                  e                                  r                                                 =                            [                                                                c                                     −                                     1                                     ,                                     c                                     −                                     1                                     ,                                     .                                     .                                     .                                     c                                     −                                     1                                              ⏟                                                      l                                  e                                  l                                  e                                  m                                  e                                  n                                  t                                  s                                                 ]                            =                            {                            9                            ,                            9                            ,                            9...                            ,                            9                            }                                  P_{q}^{Inter }=[\underbrace{c-1, c-1, ... c-1}_{l elements }] = \{9,9,9...,9\}                     PqInter​=[lelements                                                         c−1,c−1,...c−1​​]={9,9,9...,9}
然后计算相对距离                                   M                              M                  M:
                                         M                            [                            i                            ]                            [                            j                            ]                            =                                       P                               q                                           I                                  n                                  t                                  r                                  a                                                 [                            i                            ]                            −                                       P                               k                                      [                            j                            ]                            =                            c                            −                            1                            −                                       P                               k                                      [                            j                            ]                                  M[j]=P_{q}^{Intra }-P_{k}[j]=c-1-P_{k}[j]                     M[j]=PqIntra​−Pk​[j]=c−1−Pk​[j]
以是就得到了图4中的矩阵值。
Successive-Chunk Attention


我们将图3与图4直接相加,然后和终极结果图5对比,发现中心不一样的地方,就是Successive-Chunk Attention要办理的问题。这个地方实在是chunk与chunk之间的交汇处,但是相对距离这么大的话不太符合,因此界说查询的下标                                             P                            q                                       S                               u                               c                               c                                                 P_q^{Succ}                  PqSucc​:
                                                    P                               q                                           S                                  u                                  c                                  c                                                 =                                                                                                             s                                              ,                                              s                                              +                                              1                                              ,                                              .                                              .                                              .                                              ,                                              s                                              +                                              w                                              −                                              1                                                          ⏟                                                                     w                                           e                                           l                                           e                                           m                                           e                                           n                                           t                                           s                                                                ,                                     c                                     −                                     1                                     ,                                     .                                     .                                     .                                     ,                                     c                                     −                                     1                                              ⏞                                                      t                                  h                                  e                                  s                                  a                                  m                                  e                                  f                                  o                                  r                                  a                                  l                                  l                                  c                                  h                                  u                                  n                                  k                                  s                                                 ]                            =                            {                            6                            ,                            7                            ,                            8                            ,                            9                            ,                            9                            ,                            9                            ,                            6                            ,                            7                            ,                            8                            ,                            9                            ,                            9                            ,                            9                            }                                  P_{q}^{Succ }=\overbrace{\underbrace{s, s+1, ..., s+w-1}^{w elements }, c-1, ..., c-1}_{the same for all chunks }] = \{6,7,8,9,9,9,6,7,8,9,9,9\}                     PqSucc​=                                                                  s,s+1,...,s+w−1​welements,c−1,...,c−1                                             ​thesameforallchunks​]={6,7,8,9,9,9,6,7,8,9,9,9}
此中                                    w                              w                  w 表现本地窗口大小,可以直接设置为预训练长度和块大小之间的差值                                    c                         −                         s                              c-s                  c−s。然后计算相对距离                                   M                              M                  M,以是就得到了图5中的矩阵值。
   【注】以上就是DCA的原理,各个参数应该有经验值,可能就要自行探索了。
  Experiments

实验与总结略。阿里的qwen-long已经证明。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

大连全瓷种植牙齿制作中心

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表