基于CKKS的非交互式安全Transformer推理实现

打印 上一主题 下一主题

主题 984|帖子 984|积分 2952

Secure Transformer Inference Made Non-interactive

本文介绍了如何使用CKKS来盘算transformer推理的每个部门。同时给出了一系列优化算法。重要涉及到的盘算算法有以下几种: 密文的压缩与分解技术、SIMD槽折叠技术、Sgn()、QuickSum、QuickMax、密文-明文矩阵相乘、密文-密文矩阵相乘法、Softmax算法、归一化、GELU函数、Argmax函数等等。此中密文的压缩与分解技术,和SIMD槽折叠技术是本文的核心创新算法。
Abstract

随着ChatGPT的遍及,安全transformer推理已经成为一个突出了研究主题。已有的解决方法通常是交互式的,涉及到客户端和服务端之间大量的通信负载和交互轮次。
本文提出NEXUS,这是第一个用于安全transformer推理的非交互式协议,此中客户端仅须要提交一个加密输入,然后等待来自服务器的加密结果即可。NEXUS的核心是两个创新的技术:SIMD密文压缩和分解技术,以及SIMD槽折叠技术。此外,同24年的另外一个解决方案相比,本方法到达了2.8倍的加速,且淘汰了368.6倍的带宽斲丧。
1 Introduction

Transformers,例如GPT和BERT,已经彻底改变了AI领域。Transformer擅长于广泛领域的应用,比如语言翻译,内容生成以及问题回答。然而这些应用总是涉及到敏感数据,从而导致越来越多地关于用户隐私的担心。例,OpenAI开发的ChatGPT作为一种在线推理服务,以及为开发人员提供的远程API,此中使用者通过提交prompts大概消息可以很容易地访问这些服务。只管这些方法是方便的,但是由于使用者提交的数据可能包含敏感信息,故而造成了严重的隐私风险。
Secure inference是一种两方暗码协议,该协议使模型推理以如下方式处理惩罚运行,即服务器S不会了解到关于客户C提交的输入的任何信息,且C不会了解到关于S的模型的任何信息,仅仅能得到最终的推理结果。
该协议大多被计划于安全CNNs推                                   [                         2                         ,                         27                         ,                         30                         ,                         36                         ]                              [2,27,30,36]                  [2,27,30,36],近来的许多工作也支持基于Transformer的模型                                   [                         10                         ,                         24                         ,                         26                         ,                         35                         ,                         38                         ,                         40                         ]                              [10,24,26,35,38,40]                  [10,24,26,35,38,40]​,值得注意的是,这些安全Transformer模型大多都是交互式的,因此会导致巨大的通信开销和交互轮次,这里我们必须夸大非交互式安全Transformer推理的重要性。
本文贡献:

本文中,我们提出了NEXUS,第一个secure transformer inference的非交互协议。通过NEXUS,C使用RNS-CKKS加密输入,S对FHE加密数据执行transformer。CKKS的SIMD技术被应用于批处理惩罚                                   N                         =                                   2                            15                                       N=2^{15}                  N=215个数据,多项式近似可以用于处理惩罚非线性函数,比如GELU,softmax,层归一化和argmax。
NEXUS不须要对模型进行任何重训练与微调,且为了进步NEXUS的效率,我们提出了两种新颖的且底子的技术。


  • SIMD密文压缩与分解:该技术可以将2N个SIMD密文压缩为一个密文,然后可以使用4N个密文—明文乘法和替换将其解压返来。该技术可以大大淘汰客户端和服务器之间传输的密文数量,而不会为后续盘算带来任何额外的开销。



  • SIMD槽折叠:在所有SIMD槽中盘算关联函数f(),例如sum和max。结果值会自动的添补SIMD密文的槽,答应将其应用于原始密文的每个槽。
本文贡献总结如下:


  • secure transformer inference的第一个非交互协议
  • 用于密文打包的SIMD密文压缩与分解技术
  • SIMD槽折叠技术,以高效操作SIMD密文的槽
  • 综合的实现与评估
2 Preliminaries

符号体系形貌如下:
NotationDescriptionNotationDescriptionCclientSserver                                                  E                                  (                                  ∗                                  )                                          E(*)                           E(∗)encryption                                                  π                                  (                                  ∗                                  )                                          \pi(*)                           π(∗)encoding                                                  E                                  n                                  c                                  (                                  ∗                                  )                                          Enc(*)                           Enc(∗)encoding+encryption                                                               a                                     ~                                                      \tilde{a}                           a~FHE ciphertext                                                  R                                  o                                  t                                  L                                  (                                  ∗                                  )                                  /                                  R                                  o                                  t                                  R                                  (                                  ∗                                  )                                          RotL(*)/RotR(*)                           RotL(∗)/RotR(∗)左旋转和右旋转                                                  S                                  u                                  b                                  s                                  (                                  ∗                                  )                                          Subs(*)                           Subs(∗)替换操作                                                  S                                  g                                  n                                  (                                  ∗                                  )                                          Sgn(*)                           Sgn(∗)sign操作                                                  L                                          L                           L乘法深度                                                               N                                     ′                                                      N'                           N′CKKS的环维数                                                  N                                          N                           N                                                  N                                  =                                               N                                     ′                                              /                                  2                                          N=N'/2                           N=N′/2                                                  A                                          A                           A输入矩阵                                                  W                                          W                           W权重矩阵 2.1 安全推理和威胁模型

安全推理是一个两方暗码学协议,其可以在C和S之间进行模型推理,与此同时还可以保护两个参与方输入隐私。它的正式定义如下:
Definition 1:

针对两方参与者,此中                                   S                              S                  S持有模型                                   M                              M                  M,且                                   C                              C                  C持有输入                                   A                              A                  A的协议                                   Π                              \Pi                  Π是安全推理协议,当且仅当以下条件满足时:
(1) 正确性: 该协议的最终输出是正确的推理结果                                   M                         (                         A                         )                              M(A)                  M(A)。
(2) 安全性:
                                    V                         i                         e                                   w                            C                            Π                                            ≈                            c                                  S                         i                                   m                            C                                  (                         A                         ,                         o                         u                         t                         )                              View^{\Pi}_C\approx_c Sim_C(A,out)                  ViewCΠ​≈c​SimC​(A,out),此中                                   V                         i                         e                                   w                            C                            Π                                       View^{\Pi}_C                  ViewCΠ​表示协议                                   Π                              \Pi                  Π执行期间                                   C                              C                  C的视角,                                   o                         u                         t                              out                  out表示推理的结果。
                                    V                         i                         e                                   w                            S                            Π                                            ≈                            S                                  S                         i                                   m                            S                                  (                         M                         )                              View^{\Pi}_S\approx_S Sim_S(M)                  ViewSΠ​≈S​SimS​(M),此中                                   V                         i                         e                                   w                            S                            Π                                       View^{\Pi}_S                  ViewSΠ​表示协议                                   Π                              \Pi                  Π执行期间                                   S                              S                  S​的视角。
                                    S                         i                                   m                            ∗                                       Sim_*                  Sim∗​可以明确为理想状态下希望实体                                   ∗                              *                  ∗可以得到的信息。
假设                                   C                              C                  C或                                   S                              S                  S为半诚实对手,其在遵守协议规范的同时也尽可能的在执行过程中手机额外的信息。且假设对手在盘算上是有限的。
2.2 Transformer

这里简单介绍一下Transformerd。
图1是transformer的布局与工作流程。它将一个表示为矩阵的嵌入通报给注意层和前馈神经网络,末了根据最终对数最大值输出一个选择向量,且,LayerNorm层被应用于每个块之后。

   transformer的布局和工作流程  Attention:

使用三个矩阵(                                             W                            Q                                  ∈                                   R                                       n                               ×                               k                                            ,                                   W                            K                                  ∈                                   R                                       n                               ×                               k                                            ,                                   W                            V                                  ∈                                   R                                       n                               ×                               k                                                 W_Q\in\mathbb{R}^{n\times k},W_K\in\mathbb{R}^{n\times k},W_V\in\mathbb{R}^{n\times k}                  WQ​∈Rn×k,WK​∈Rn×k,WV​∈Rn×k)乘嵌入矩阵                                   A                         ∈                                   R                                       m                               ×                               n                                                 A\in \mathbb{R}^{m\times n}                  A∈Rm×n,生成一个query矩阵                                   Q                         =                         A                         ⋅                                   W                            Q                                       Q = A·W_Q                  Q=A⋅WQ​,一个key矩阵                                   K                         =                         A                         ⋅                                   W                            K                                       K=A·W_K                  K=A⋅WK​和一个value矩阵                                   V                         =                         A                         ⋅                                   W                            V                                       V=A·W_V                  V=A⋅WV​。即对于Attention层的单元,transformer会学习到三个权重矩阵。
attention可以被表示为:
                                         A                            t                            t                            e                            n                            t                            i                            o                            n                            (                            Q                            ,                            K                            ,                            V                            )                            =                            S                            o                            f                            t                            m                            a                            x                            (                                                   Q                                               K                                     T                                                                  k                                                 )                            ⋅                            V                                  Attention(Q,K,V) = Softmax({QK^T\over{\sqrt k}})·V                     Attention(Q,K,V)=Softmax(k                     ​QKT​)⋅V
Layer normalization

该层的输入为                                   a                         ∈                                   R                            n                                       a\in \mathbb{R}^n                  a∈Rn,均值和标准差分别为                                   μ                              \mu                  μ和                                   σ                              \sigma                  σ,则该层的输出                                   y                         ∈                                   R                            n                                       y\in\mathbb{R}^n                  y∈Rn可以表示为:
                                                    y                               i                                      =                            γ                            ⋅                                                                x                                     i                                              −                                  μ                                          σ                                      +                            β                                  y_i=\gamma·{x_i-\mu\over\sigma}+\beta                     yi​=γ⋅σxi​−μ​+β
此中,                                   γ                         ,                         β                         ∈                         R                              \gamma,\beta\in\mathbb{R}                  γ,β∈R​是两个超参数。
Feed-forward

全毗连前馈网络层包含两个线性变更以及一个GELU激活函数:
                                         F                            e                            e                            d                            F                            o                            r                            w                            a                            r                            d                            (                            X                            )                            =                            G                            E                            L                            U                            (                            X                                       W                               1                                      +                                       b                               1                                      )                            ⋅                                       W                               2                                      +                                       b                               2                                            FeedForward(X)=GELU(XW_1+b_1)·W_2+b_2                     FeedForward(X)=GELU(XW1​+b1​)⋅W2​+b2​
此中GELU函数盘算如下:
                                         G                            E                            L                            U                            (                            x                            )                            =                                       1                               2                                      x                            ⋅                            (                            1                            +                            e                            r                            f                            (                                       x                                           2                                                 )                            )                                  GELU(x)={1\over 2}x·(1+erf({x\over \sqrt 2}))                     GELU(x)=21​x⋅(1+erf(2                     ​x​))
式中,高斯偏差函数为                                   e                         r                         f                         (                         x                         )                         =                                   2                                       π                                                      ∫                            0                            x                                            e                                       −                                           t                                  2                                                       d                         t                              erf(x)={2\over\sqrt{\pi}}\int_0^xe^{-t^2}dt                  erf(x)=π                     ​2​∫0x​e−t2dt​。由于其良好的曲率和非单调性,它被用作激活函数。
Argmax

根据最终对数最大值输出一个选择向量
可以看到,只要我们可以大概使用FHE实现各个层的盘算,就可以实现一个安全Transformer。
2.3 Fully Homomorphic Encryption

FHE可以对加密数据执行恣意操作,故FHE是使得我们构建非交互式安全transformer推理得重要工具。RNS-CKKS属于级全同态加密,其可以支持L级深度的乘法。RNS-CKKS的明文和密文均是多项式环                                             R                            Q                                  =                                   Z                            Q                                  [                         X                         ]                         /                         (                                   X                                       N                               ′                                            +                         1                         )                              R_Q=\Z_Q[X]/(X^{N'}+1)                  RQ​=ZQ​[X]/(XN′+1)上的元素。此中                                   Q                         =                                   Π                                       i                               =                               0                                      L                                            q                            i                                       Q=\Pi^L_{i=0}q_i                  Q=Πi=0L​qi​,且                                             q                            i                                       q_i                  qi​​之间互素。若密文的级别变得太低,则可以运行自举操作来革新密文到高的级别,以答应更多的盘算。
简单地说,自举即使用自同构                                             R                                       q                               0                                            ≅                                   R                                       q                               0                                            ×                                   R                                       q                               1                                            ×                         .                         .                         .                         ×                                   R                                       q                               L                                                 R_{q_0}\cong R_{q_0}\times R_{q_1}\times ... \times R_{q_L}                  Rq0​​≅Rq0​​×Rq1​​×...×RqL​​,来将密文模从                                             q                            0                                       q_0                  q0​提拔到                                             q                            L                                       q_L                  qL​,以及对密文同态评估解密电路。若自举自己斲丧K个级别,则革新后的密文支持                                   L                         −                         K                              L-K                  L−K个级深度的盘算。
RNS-CKKS支持SIMD操作,其可以加密向量                                   a                         ∈                                   R                            N                                       a\in \R^N                  a∈RN到一个密文中,且批处理惩罚这些加密元素,而不引入其他操作。为了以SIMD格式加密,首先使用编码算法                                   π                         (                         ∗                         )                              \pi(*)                  π(∗)将向量                                   a                              a                  a编码为一个                                             R                            Q                                       R_Q                  RQ​上的多项式,然后使用加密算法                                   E                         (                         ∗                         )                              E(*)                  E(∗)​加密该多项式。
在整篇文章中,我们使用                                   E                         (                         ∗                         )                              E(*)                  E(∗)表示加密多项式,使用                                   E                         n                         c                         (                         ∗                         )                              Enc(*)                  Enc(∗)表示以SIMD格式加密向量,即                                   E                         n                         c                         (                         a                         )                         =                         E                         (                         π                         (                         a                         )                         )                              Enc(a)=E(\pi(a))                  Enc(a)=E(π(a)),此中                                   a                              a                  a是一个向量。

一个特殊的FHE操作:
                                    c                                   t                            ′                                  ←                         S                         u                         b                         s                         (                         c                         t                         ,                         k                         )                              ct'\leftarrow Subs(ct,k)                  ct′←Subs(ct,k):替换操作,该操作以密文                                   c                         t                         =                         E                         (                         p                         (                         x                         )                         )                              ct=E(p(x))                  ct=E(p(x))以及一个奇整数                                   k                              k                  k作为输入,然后得到新的密文                                   c                                   t                            ′                                  =                         E                         (                         p                         (                                   x                            k                                  )                         )                              ct'=E(p(x^k))                  ct′=E(p(xk))​​​​。
这里的                                   S                         u                         b                         s                         (                         c                         t                         ,                         k                         )                              Subs(ct,k)                  Subs(ct,k)应该是一种密钥互换操作,可以形貌如下:
已知密文:                                    c                         t                         =                         (                         −                         a                         (                         x                         )                         s                         (                         x                         )                         +                         e                         (                         x                         )                         +                         p                         (                         x                         )                         ,                         a                         (                         x                         )                         )                              ct=(-a(x)s(x)+e(x)+p(x),a(x))                  ct=(−a(x)s(x)+e(x)+p(x),a(x))
将该密文进行自同构操作:                                              κ                            k                                  (                         c                         t                         )                         =                         (                         −                         a                         (                                   x                            k                                  )                         s                         (                                   x                            k                                  )                         +                         e                         (                                   x                            k                                  )                         +                         p                         (                                   x                            k                                  )                         ,                         a                         (                                   x                            k                                  )                         )                              \kappa_k(ct)=(-a(x^k)s(x^k)+e(x^k)+p(x^k),a(x^k))                  κk​(ct)=(−a(xk)s(xk)+e(xk)+p(xk),a(xk))。
然后得到用户提供的互换密钥:                                    k                         e                         y                         =                         (                         −                         a                         (                         x                         )                         s                         (                         x                         )                         +                         e                         (                         x                         )                         +                         P                         ⋅                         s                         (                                   x                            k                                  )                         ,                         a                         (                         x                         )                         )                              key = (-a(x)s(x)+e(x)+P·s(x^k),a(x))                  key=(−a(x)s(x)+e(x)+P⋅s(xk),a(x))
然后执行密钥互换操作:                                    c                                   t                            ′                                  =                         (                                   κ                            k                                  (                         c                         t                         )                         [                         0                         ]                         ,                         0                         )                         +                         (                         ⌊                                   P                                       −                               1                                            ⋅                                   κ                            k                                  (                         c                         t                         )                         [                         1                         ]                         ⋅                         k                         e                         y                         ⌉                         )                              ct'=(\kappa_k(ct)[0],0)+(\lfloor P^{-1}·\kappa_k(ct)[1]·key\rceil)                  ct′=(κk​(ct)[0],0)+(⌊P−1⋅κk​(ct)[1]⋅key⌉)
此时新的密文即                                   c                                   t                            ′                                  =                         (                         −                         a                         (                         x                         )                         s                         (                         x                         )                         +                         e                         (                         x                         )                         +                         p                         (                                   x                            k                                  )                         ,                         a                         (                         x                         )                         )                              ct'=(-a(x)s(x)+e(x)+p(x^k),a(x))                  ct′=(−a(x)s(x)+e(x)+p(xk),a(x))
注意,这里的                                   a                         (                         x                         )                         ,                         e                         (                         x                         )                              a(x),e(x)                  a(x),e(x)是变化的,也就是差别的密文中,这是差别的。

2.4 Homomorphic sign function

由于FHE仅支持线性函数,所以为了实如今FHE下对加密数据的比力,本文须要使用sign函数的多项式近似,即:
                                         s                            i                            g                            n                            (                            x                            )                            =                                       f                                           d                                  f                                                 (                                       g                                           d                                  g                                                 (                            x                            )                            )                            =                                       {                                                                                                     −                                              1                                                                                                                            (                                              −                                              1                                              ≤                                              x                                              ≤                                              −                                                               2                                                                   −                                                    α                                                                               )                                                                                                                                  0                                                                                                             (                                              x                                              =                                              0                                              )                                                                                                                                  1                                                                                                             (                                                               2                                                                   −                                                    α                                                                               ≤                                              x                                              ≤                                              1                                              )                                                                                                             sign(x)=f^{d_f}(g^{d_g}(x))=\begin{cases} -1 &(-1\leq x \leq -2^{-\alpha}) \\ 0 &(x = 0) \\ 1 &(2^{-\alpha}\leq x \leq 1) \\ \end{cases}                     sign(x)=fdf​(gdg​(x))=⎩               ⎨               ⎧​−101​(−1≤x≤−2−α)(x=0)(2−α≤x≤1)​
此中,                                   f                         (                         )                         ,                         g                         (                         )                              f(),g()                  f(),g()为两个多项式,                                             d                            f                                  ,                                   d                            g                                       d_f,d_g                  df​,dg​为这两个多项式重复的次数。注意,该多项式近似要求输入x取值范围为[-1,1]。因此,对任何输入                                   a                         ∈                         [                                   a                                       m                               i                               n                                            ,                                   a                                       m                               a                               x                                            ]                              a\in [a_{min},a_{max}]                  a∈[amin​,amax​]都须要进行归一化处理惩罚:
                                         x                            :                            =                            a                            /                            m                            a                            x                            {                            ∣                                       a                                           m                                  a                                  x                                                 ∣                            ,                            ∣                                       a                                           m                                  i                                  n                                                 ∣                            }                                  x := a/max\{|a_{max}|,|a_{min}|\}                     x:=a/max{∣amax​∣,∣amin​∣}
这里,我们使用Sgn()表示在SIMD密文上同时运行归一化与sign近似函数:
                                                    b                               ~                                      ←                            S                            g                            n                            (                                       a                               ~                                      )                            :                                       b                               i                                      =                                       f                                           d                                  f                                                 (                                       g                                           d                                  g                                                 (                                                   a                                  i                                                      m                                  a                                  x                                  {                                  ∣                                               a                                                   m                                        a                                        x                                                           ∣                                  ,                                  ∣                                               a                                                   m                                        i                                        n                                                           ∣                                  }                                                 )                            )                                                          ∀                            i                            ∈                            [                            N                            ]                                  \widetilde b\leftarrow Sgn(\widetilde a): b_i =f^{d_f}(g^{d_g}({a_i\over{max\{|a_{max}|,|a_{min}|\}}})) \ \ \forall i \in [N]                     b             ←Sgn(a             ):bi​=fdf​(gdg​(max{∣amax​∣,∣amin​∣}ai​​))  ∀i∈[N]
在本文的实现中,使用的是9次的                                   f                         (                         ∗                         )                              f(*)                  f(∗)和                                   g                         (                         ∗                         )                              g(*)                  g(∗),且计划                                   α                         =                         16                         ,                                   d                            f                                  =                         2                         ,                                   d                            g                                  =                         2                              \alpha=16,d_f=2,d_g=2                  α=16,df​=2,dg​=2​​,然后使用BSGS算法来评估多项式。
3 Basic design

本节介绍NEXUS的底子计划,即在不优化的情况下实现上述transformer的每一层盘算,在之后的章节中会对本节的算法进行优化。
3.1 Attention

3.1.1 Matrix multiplication(ciphertext-plaintext)

在Attention层的第一个MatMul步骤,我们须要盘算三个密文—明文矩阵乘法:
                                         Q                            :                            =                            A                            ⋅                                       W                               Q                                      ;                                     K                            :                            =                            A                            ⋅                                       W                               K                                      ;                                     V                            :                            =                            A                            ⋅                                       W                               V                                      ;                                  Q:=A·W_Q;\\ K:=A·W_K;\\ V:=A·W_V;                     Q:=A⋅WQ​;K:=A⋅WK​;V:=A⋅WV​;
此中A是我们的输入,                                             W                            Q                                  ,                                   W                            K                                  ,                                   W                            V                                       W_Q,W_K,W_V                  WQ​,WK​,WV​是三个给定矩阵,下面以                                   A                         ⋅                                   W                            Q                                       A·W_Q                  A⋅WQ​为例来形貌这个密文—明文矩阵乘法,该过程同样实用于                                             W                            K                                       W_K                  WK​和                                             W                            V                                       W_V                  WV​。
给定矩阵                                   A                         ∈                                   R                                       m                               ×                               n                                                 A\in \mathbb{R}^{m\times n}                  A∈Rm×n和矩阵                                             W                            Q                                  ∈                                   R                                       n                               ×                               k                                                 W_Q\in \mathbb{R}^{n\times k}                  WQ​∈Rn×k,盘算矩阵                                   Q                         :                         =                         A                         ⋅                                   W                            Q                                       Q:=A·W_Q                  Q:=A⋅WQ​。
设                                             a                                       i                               ,                               j                                            ∈                         R                              a_{i,j}\in \mathbb{R}                  ai,j​∈R表示矩阵A的第i行第j列的元素,                                             w                            j                                  ∈                                   R                            k                                       w_j\in \mathbb{R}^k                  wj​∈Rk表示矩阵                                             W                            Q                                       W_Q                  WQ​的第j行的元素向量,                                             q                            i                                  ∈                                   R                            k                                       q_i\in \mathbb{R}^k                  qi​∈Rk是矩阵                                   Q                              Q                  Q的第i行的元素向量,即:
                                                    q                               i                                      =                                       ∑                                           j                                  ∈                                  [                                  n                                  ]                                                            a                                           i                                  ,                                  j                                                 ⋅                                       w                               j                                            q_i=\sum_{j\in [n]}a_{i,j}·w_j                     qi​=j∈[n]∑​ai,j​⋅wj​
因此,上述过程可以形貌为,C将A中的每个元素                                             a                                       i                               ,                               j                                                 a_{i,j}                  ai,j​均单独加密为密文发送给S,然后S同态评估MatrixMul,一个演示的示例如下:

   图2 SIMD-based matrix multiplication  
在上述形貌中,C须要发送                                   m                         ×                         n                              m\times n                  m×n个密文给S,从某一方面来说,这种开销是比力大的,因此本文在第4节提出一种算法可以将云云范例的                                   m                         ×                         n                              m\times n                  m×n个密文压缩为                                                        m                               ×                               n                                                 N                               ′                                                 m\times n\over{N'}                  N′m×n​个密文,即一个密文中存放                                             N                            ′                                       N'                  N′个元素,随后S可以将压缩后的密文恢复为压缩前的密文形式。

3.1.2 Matrix multiplication(ciphertext-ciphertext)

颠末上述步骤后,可以得到加密的                                   (                         Q                         ,                         K                         ,                         V                         )                              (Q,K,V)                  (Q,K,V),在Attention的第二个MatMul块,S须要盘算                                   Q                         ⋅                                   K                            T                                       Q·K^T                  Q⋅KT。很明显,如今Q的每一行和                                             K                            T                                       K^T                  KT的每一列已经以SIMD的形式加密为                                   E                         n                         c                         (                         q                         )                         ,                         E                         n                         c                         (                                   k                            T                                  )                              Enc(q) , Enc(k^T)                  Enc(q),Enc(kT)。假如S可以盘算                                   E                         n                         c                         (                         q                         )                              Enc(q)                  Enc(q)和                                   E                         n                         c                         (                                   k                            T                                  )                              Enc(k^T)                  Enc(kT)的内积,则可以得到                                   Q                         ⋅                                   K                            T                                       Q·K^T                  Q⋅KT的加密结果。
由于SIMD,S可以很容易的盘算得到Enc(u),此中                                   u                         =                         [                                   u                            0                                  ,                         .                         .                         .                         ,                                   u                                       k                               −                               1                                            ]                              u=[u_0,...,u_{k-1}]                  u=[u0​,...,uk−1​]是q和                                             k                            T                                       k^T                  kT的元素级的乘法,如今为了盘算内积,S仅仅须要在SIMD下盘算                                   s                         :                         =                                   ∑                                       i                               =                               0                                                 k                               −                               1                                                      u                            i                                       s:=\sum_{i=0}^{k-1}u_i                  s:=∑i=0k−1​ui​。

为了盘算这个和,我们可以通过k-1次的旋转及加和来盘算,从而得到密文Enc([s,s,…,s]),但是本文提出了’QuickSum’算法,该算法仅仅须要logk次旋转就可到达这个目标。'QuickSum’算法在第5节介绍。

进一步,S将盘算得到每一行的的m个密文组合到单一密文中,盘算方法如下:
                                                    ∑                                           i                                  =                                  0                                                      m                                  −                                  1                                                 (                            E                            n                            c                            (                                       s                               i                                      ,                                       s                               i                                      ,                            .                            .                            .                            ,                                       s                               i                                      )                            ⋅                                       b                               i                                      )                                  \sum_{i=0}^{m-1}(Enc(s_i,s_i,...,s_i)·b_i)                     i=0∑m−1​(Enc(si​,si​,...,si​)⋅bi​)
此中                                             b                            i                                       b_i                  bi​仅在第i个槽的位置是1,其余槽均为0。
易知,输出矩阵为                                   A                         ∈                                   R                                       m                               ×                               m                                                 A\in \mathbb{R}^{m\times m}                  A∈Rm×m,此中A的每行向量以SIMD形式加密,将该结果作为Softmax的输入。
3.1.3 Softmax

Softmax函数须要被应用于A的每一行,该函数评估如下:
                                                                                                     y                                        i                                                  =                                                                  e                                           x                                           p                                           (                                                           a                                              i                                                          −                                                           a                                                               m                                                 a                                                 x                                                                          )                                                                                     ∑                                                               j                                                 =                                                 0                                                                               m                                                 −                                                 1                                                                          e                                           x                                           p                                           (                                                           a                                              j                                                          −                                                           a                                                               m                                                 a                                                 x                                                                          )                                                                                                       (1)                                                       y_i={exp(a_i-a_{max})\over{\sum_{j=0}^{m-1}exp(a_j-a_{max})}}\tag{1}                     yi​=∑j=0m−1​exp(aj​−amax​)exp(ai​−amax​)​(1)
此中                                             a                                       m                               a                               x                                            =                         m                         a                         x                         (                                   a                            0                                  ,                         .                         .                         .                         ,                                   a                                       m                               −                               1                                            )                              a_{max}=max(a_0,...,a_{m-1})                  amax​=max(a0​,...,am−1​),从而确保指数函数的每个输入                                   (                                   a                            j                                  −                                   a                                       m                               a                               x                                            )                              (a_j-a_{max})                  (aj​−amax​)​好坏正数,保证稳固性。

本文提出了’QuickMax’算法,该算法以                                   E                         n                         c                         (                         [                                   a                            0                                  ,                         .                         .                         .                         ,                                   a                                       m                               −                               1                                            ]                         )                              Enc([a_0,...,a_{m-1}])                  Enc([a0​,...,am−1​])为输入,并输出                                   E                         n                         c                         (                         [                                   a                                       m                               a                               x                                            ,                         .                         .                         .                         ,                                   a                                       m                               a                               x                                            ]                         )                              Enc([a_{max},...,a_{max}])                  Enc([amax​,...,amax​])​,且,该算法仅须要logm-1次Sgn操作与logm次旋转操作。该算法形貌在第5节。

给定                                   E                         n                         c                         (                         [                                   a                            0                                  ,                         .                         .                         .                         ,                                   a                                       m                               −                               1                                            ]                         )                              Enc([a_0,...,a_{m-1}])                  Enc([a0​,...,am−1​])和                                   E                         n                         c                         (                         [                                   a                                       m                               a                               x                                            ,                         .                         .                         .                         ,                                   a                                       m                               a                               x                                            ]                         )                              Enc([a_{max},...,a_{max}])                  Enc([amax​,...,amax​])。
S进行如下步骤盘算:
                                         E                            n                            c                            (                            [                                       a                               0                               ′                                      ,                            .                            .                            .                            ,                                       a                                           m                                  −                                  1                                          ′                                      ]                            )                            =                            E                            n                            c                            (                            [                                       a                               0                                      ,                            .                            .                            .                            ,                                       a                                           m                                  −                                  1                                                 ]                            )                            −                            E                            n                            c                            (                            [                                       a                                           m                                  a                                  x                                                 ,                            .                            .                            .                            ,                                       a                                           m                                  a                                  x                                                 ]                            )                                  Enc([a'_0,...,a'_{m-1}])=Enc([a_0,...,a_{m-1}])-Enc([a_{max},...,a_{max}])                     Enc([a0′​,...,am−1′​])=Enc([a0​,...,am−1​])−Enc([amax​,...,amax​])
然后根据如下公式盘算指数函数,这里使用泰勒展开:
                                         e                            x                            p                            (                            x                            )                            ≈                            (                            1                            +                                       x                                           2                                  r                                                            )                                           2                                  r                                                 ,                            x                            ≤                            0                                  exp(x)\approx(1+{x\over{2^r}})^{2^r},x\leq 0                     exp(x)≈(1+2rx​)2r,x≤0
此中                                   r                         =                         6                              r=6                  r=6,此时平均偏差被限定在                                   1                                   0                                       −                               5                                                 10^{-5}                  10−5,即S以SIMD格式盘算指数函数:
                                         E                            n                            c                            (                                       e                               0                                      ,                            .                            .                            .                            ,                                       e                                           m                                  −                                  1                                                 )                            =                            e                            x                            p                            (                            E                            n                            c                            (                            [                                       a                               0                               ′                                      ,                            .                            .                            .                            ,                                       a                                           m                                  −                                  1                                          ′                                      ]                            )                            )                                  Enc(e_0,...,e_{m-1})=exp(Enc([a'_0,...,a'_{m-1}]))                     Enc(e0​,...,em−1​)=exp(Enc([a0′​,...,am−1′​]))
很明显,这里                                             e                            j                                  =                         e                         x                         p                         (                                   a                            j                            ′                                  )                              e_j=exp(a'_j)                  ej​=exp(aj′​)。
接下来,S应用                                   Q                         u                         i                         c                         k                         S                         u                         m                         (                         ∗                         )                              QuickSum(*)                  QuickSum(∗)算法来得到                                   E                         n                         c                         (                         [                                   ∑                                       j                               =                               0                                                 m                               −                               1                                                      e                            j                                  ,                         .                         .                         .                         ,                                   ∑                                       j                               =                               0                                                 m                               −                               1                                                      e                            j                                  ]                         )                              Enc([\sum^{m-1}_{j=0}e_j,...,\sum^{m-1}_{j=0}e_j])                  Enc([∑j=0m−1​ej​,...,∑j=0m−1​ej​])​。
进一步的,S使用文献[21,24]中的Goldschmidt除法算法来盘算:
                                         E                            n                            c                            (                                       y                               0                                      ,                            .                            .                            .                            ,                                       y                                           m                                  −                                  1                                                 )                            =                                                   E                                  n                                  c                                  (                                               e                                     0                                              ,                                  .                                  .                                  .                                  ,                                               e                                                   m                                        −                                        1                                                           )                                                      E                                  n                                  c                                  (                                  [                                               ∑                                                   j                                        =                                        0                                                                m                                        −                                        1                                                                        e                                     j                                              ,                                  .                                  .                                  .                                  ,                                               ∑                                                   j                                        =                                        0                                                                m                                        −                                        1                                                                        e                                     j                                              ]                                  )                                                       Enc(y_0,...,y_{m-1})={Enc(e_0,...,e_{m-1})\over Enc([\sum^{m-1}_{j=0}e_j,...,\sum^{m-1}_{j=0}e_j])}                     Enc(y0​,...,ym−1​)=Enc([∑j=0m−1​ej​,...,∑j=0m−1​ej​])Enc(e0​,...,em−1​)​
Softmax算法的详细形貌如算法1所示:

3.1.4 Matrix multiplication(ciphertext-ciphertext)

这里是Attention的末了一个MatMul块,该块的盘算原理同3.1.2节完全一致。
3.2 Layer normalization

本文的归一化表示如下(但是不太清楚这个归一化使用的是什么盘算公式):

3.3 Feed forward

前馈网络层涉及到两个矩阵乘法以及一个GELU。矩阵乘法如上文所述来盘算。GELU可以使用下述分段多项式来近似,当输入                                   x                         ∈                         [                         −                         60                         ,                         60                         ]                              x\in [-60,60]                  x∈[−60,60],则可以确保偏差在                                   1                                   0                                       −                               3                                                 10^{-3}                  10−3内。
                                         G                            E                            L                            U                            (                            x                            )                            =                            ∈                                       {                                                                                     0                                                                                                             (                                              x                                              ≤                                              −                                              4                                              )                                                                                                                                                  P                                              (                                              x                                              )                                              =                                                               ∑                                                                   i                                                    =                                                    0                                                                                    i                                                    =                                                    3                                                                                                c                                                 i                                                                               x                                                 i                                                                                                                                            (                                              −                                              4                                              <                                              x                                              ≤                                              −                                              1.95                                              )                                                                                                                                                  Q                                              (                                              x                                              )                                              =                                                               ∑                                                                   i                                                    =                                                    0                                                                                    i                                                    =                                                    6                                                                                                d                                                 i                                                                               x                                                 i                                                                                                                                            (                                              −                                              1.95                                              <                                              x                                              ≤                                              3                                              )                                                                                                                                  x                                                                                                             (                                              x                                              >                                              3                                              )                                                                                                             GELU(x)=\in \begin{cases} 0 &(x\leq -4) \\ P(x)=\sum_{i=0}^{i=3}c_ix^i &(-4<x\leq -1.95) \\ Q(x)=\sum_{i=0}^{i=6}d_ix^i &(-1.95<x\leq 3) \\ x &(x>3) \end{cases}                     GELU(x)=∈⎩               ⎨               ⎧​0P(x)=∑i=0i=3​ci​xiQ(x)=∑i=0i=6​di​xix​(x≤−4)(−4<x≤−1.95)(−1.95<x≤3)(x>3)​
首先,使用Sgn操作得到四个加密bit:                                             b                            0                                  ,                                   b                            1                                  ,                                   b                            2                                  ,                                   b                            3                                       b_0,b_1,b_2,b_3                  b0​,b1​,b2​,b3​,当且仅当输入x属于第i段时,                                             b                            i                                  =                         1                              b_i=1                  bi​=1,否则                                             b                            i                                  =                         0                              b_i=0                  bi​=0,云云,GELU(x)函数可以表示为:                                   G                         E                         L                         U                         (                         x                         )                         :                         =                                   b                            0                                  ⋅                         0                         +                                   b                            1                                  ⋅                         P                         (                         x                         )                         +                                   b                            2                                  ⋅                         Q                         (                         x                         )                         +                                   b                            3                                  ⋅                         x                              GELU(x):=b_0·0+b_1·P(x)+b_2·Q(x)+b_3·x                  GELU(x):=b0​⋅0+b1​⋅P(x)+b2​⋅Q(x)+b3​⋅x​。
完整的Secure GELU算法可以表示如下:

3.4 Argmax

transformer最终的输出应该是一个选择向量                                   E                         n                         c                         (                         [                                   b                            0                                  ,                         .                         .                         .                         ,                                   b                                       m                               −                               1                                            ]                         )                              Enc([b_0,...,b_{m-1}])                  Enc([b0​,...,bm−1​]),此中                                             b                            i                                  =                         1                                                   i                         f                                                             a                            i                                  =                         m                         a                         x                         (                                   a                            0                                  ,                         .                         .                         .                         ,                                   a                                       m                               −                               1                                            )                              b_i=1 \ if \ a_i=max(a_0,...,a_{m-1})                  bi​=1 if ai​=max(a0​,...,am−1​),其他情况下                                             b                            i                                  =                         0                              b_i=0                  bi​=0​。
因此,本文的Secure Argmax算法如下:

3.5 Placement of bootstrapping

由于bootstrapping操作是昂贵的,因此合理的放置bootstrapping的位置是至关重要的。

   图4 Placement of bootstrapping for a BERT-base transformer  4. SIMD密文的压缩和分解

假设C想要发送N’个密文给S,且每个密文以SIMD方式加密N个相同的值,Enc([                                             a                            0                                  ,                         .                         .                         .                         ,                                   a                            0                                       a_0,...,a_0                  a0​,...,a0​]),…,Enc([                                             a                                                   N                                  ′                                          −                               1                                            ,                         .                         .                         .                         ,                                   a                                                   N                                  ′                                          −                               1                                                 a_{N'-1},...,a_{N'-1}                  aN′−1​,...,aN′−1​​​])。

SIMD密文的压缩算法
C将向量[                                             a                            0                                  ,                                   a                            1                                  ,                         .                         .                         .                         ,                                   a                                                   N                                  ′                                          −                               1                                                 a_0,a_1,...,a_{N'-1}                  a0​,a1​,...,aN′−1​]的各个元素打包到一个多项式的系数中,即:
                                         p                            (                            x                            )                            =                                       a                               0                                      +                                       a                               1                                      x                            +                                       a                               2                                                 x                               2                                      +                            .                            .                            .                            +                                       a                                                        N                                     ′                                              −                                  1                                                            x                                                        N                                     ′                                              −                                  1                                                       p(x)=a_0+a_1x+a_2x^2+...+a_{N'-1}x^{N'-1}                     p(x)=a0​+a1​x+a2​x2+...+aN′−1​xN′−1
然后将该多项式加密                                                        p                               ~                                      0                                  =                         E                         (                         p                         (                         x                         )                         )                              \widetilde p_0=E(p(x))                  p             ​0​=E(p(x))发送给S。
然后S可以对密文                                                        p                               ~                                      0                                       \widetilde p_0                  p             ​0​分解从而得到压缩前                                             N                            ′                                       N'                  N′个SIMD密文。

S分解密文                                                        p                               ~                                      0                                       \widetilde p_0                  p             ​0​过程如下:

SIMD密文的分解算法:
(1)执行                                   S                         u                         b                         s                         (                                              p                               ~                                      0                                  ,                                   N                            ′                                  +                         1                         )                              Subs(\widetilde p_0, N'+1)                  Subs(p             ​0​,N′+1)返回:
                                         E                            (                                       a                               0                                      +                                       a                               1                                                 x                                                        N                                     ′                                              +                                  1                                                 +                                       a                               2                                                 x                                           (                                               N                                     ′                                              +                                  1                                               )                                     2                                                             +                            .                            .                            .                            +                                       a                                                        N                                     ′                                              −                                  1                                                            x                                           (                                               N                                     ′                                              +                                  1                                               )                                                                  N                                           ′                                                      −                                        1                                                                          )                                     =                            E                            (                                       a                               0                                      +                                       a                               1                                      (                            −                            x                            )                            +                                       a                               2                                      (                            −                            x                                       )                               2                                      )                            +                            .                            .                            .                            +                                       a                                                        N                                     ′                                              −                                  1                                                 (                            −                            x                                       )                                                        N                                     ′                                              −                                  1                                                 )                                  E(a_0+a_1x^{N'+1}+a_2x^{(N'+1)^2}+...+a_{N'-1}x^{(N'+1)^{N'-1}}) \\ =E(a_0+a_1(-x)+a_2(-x)^2)+...+a_{N'-1}(-x)^{N'-1})                     E(a0​+a1​xN′+1+a2​x(N′+1)2+...+aN′−1​x(N′+1)N′−1)=E(a0​+a1​(−x)+a2​(−x)2)+...+aN′−1​(−x)N′−1)
注意,                                             x                                       N                               ′                                            +                         1                         ≡                         0                                                    (                         m                         o                         d                                                              x                                       N                               ′                                            +                         1                         )                              x^{N'}+1 \equiv 0 \ \ (mod \ \ x^{N'} + 1)                  xN′+1≡0  (mod  xN′+1),因此                                             x                                                   N                                  ′                                          +                               1                                            =                                   x                                       N                               ′                                            ∗                         x                         =                         −                         x                                                   (                         m                         o                         d                                                              x                                       N                               ′                                            +                         1                         )                              x^{N'+1} = x^{N'} * x = -x \ (mod \ \ x^{N'}+1)                  xN′+1=xN′∗x=−x (mod  xN′+1)​,这里的                                   N                         ’                              N’                  N’也就是分圆环的次数。
(2)执行                                                        p                               ~                                      0                                  +                         S                         u                         b                         s                         (                                              p                               ~                                      0                                  ,                                   N                            ′                                  +                         1                         )                              \widetilde p_0+Subs(\widetilde p_0,N'+1)                  p             ​0​+Subs(p             ​0​,N′+1)​操作,移除p(x)的所有奇数项。
                                                    a                               0                                      +                                       a                               1                                      x                            +                                       a                               2                                                 x                               2                                      +                            .                            .                            .                            +                                       a                                                        N                                     ′                                              −                                  1                                                            x                                                        N                                     ′                                              +                                  1                                                          +                                       a                               0                                      +                                       a                               1                                      (                            −                            x                            )                            +                                       a                               2                                      (                            −                            x                                       )                               2                                      )                            +                            .                            .                            .                            +                                       a                                                        N                                     ′                                              −                                  1                                                 (                            −                            x                                       )                                                        N                                     ′                                              −                                  1                                                          =                                       a                               0                                      +                            0                            x                            +                                       a                               2                                                 x                               2                                      +                            .                            .                            .                            +                                       a                                                        N                                     ′                                              −                                  2                                                            x                                                        N                                     ′                                              −                                  2                                                 +                            0                                       x                                                        N                                     ′                                              −                                  1                                                       a_0+a_1x+a_2x^2+...+a_{N'-1}x^{N'+1} \\+ a_0+a_1(-x)+a_2(-x)^2)+...+a_{N'-1}(-x)^{N'-1}\\= a_0+0x+a_2x^2+...+a_{N'-2}x^{N'-2}+0x^{N'-1}                     a0​+a1​x+a2​x2+...+aN′−1​xN′+1+a0​+a1​(−x)+a2​(−x)2)+...+aN′−1​(−x)N′−1=a0​+0x+a2​x2+...+aN′−2​xN′−2+0xN′−1
(3)通过                                   l                         o                         g                                   N                            ′                                       log N'                  logN′次                                   S                         u                         b                         s                         (                         )                              Subs()                  Subs()操作,S可以提取得到密文:                                   E                         (                                   a                            0                                  +                         0                                   x                            1                                  +                         0                                   x                            2                                  +                         .                         .                         .                         +                         0                                   x                                                   N                                  ′                                          −                               1                                            )                              E(a_0+0x^1+0x^2+...+0x^{N'-1})                  E(a0​+0x1+0x2+...+0xN′−1),实际上,这就是密文Enc([                                             a                            0                                  ,                                   a                            0                                  ,                         .                         .                         .                         ,                                   a                            0                                       a_0,a_0,...,a_0                  a0​,a0​,...,a0​])。完整的操作流程如下:

雷同地,为了提取E(                                             a                            1                                  +                         0                                   x                            1                                  +                         .                         .                         .                         +                         0                                   x                                                   N                                  ′                                          −                               1                                                 a_1+0x^1+...+0x^{N'-1}                  a1​+0x1+...+0xN′−1),S应该左旋明文多项式p(x)一个单元,通过乘以                                             x                                       −                               1                                                 x^{-1}                  x−1,然后再次执行上述的提取过程。通过执行                                   N                         ‘                              N‘                  N‘次该提取过程,S可以得到向量[                                             a                            0                                  ,                                   a                            1                                  ,                         .                         .                         .                         ,                                   a                                                   N                                  ′                                          −                               1                                                 a_0,a_1,...,a_{N'-1}                  a0​,a1​,...,aN′−1​​]中每个元素的单独SIMD格式加密。
然而上述过程须要执行                                   (                                   N                            ′                                  ⋅                         l                         o                         g                                   N                            ′                                  )                              (N'·logN')                  (N′⋅logN′)次                                   S                         u                         b                         s                         (                         )                              Subs()                  Subs()操作。对比之下,本文提出一种算法,可以实现相同的目标,但是仅须要                                   2                                   N                            ′                                       2N'                  2N′次                                   S                         u                         b                         s                         (                         )                              Subs()                  Subs()操作。该算法可以简单地形貌如下:


算法5是Secure Decompression的详细形貌:

下面提供上述分解操作的理论证明:
Theorem 1:

仅有常数项的多项式的加密E(                                             a                            s                                  +                         0                                   x                            1                                  +                         .                         .                         .                         +                         0                                   x                                                   N                                  ′                                          −                               1                                                 a_s+0x^1+...+0x^{N'-1}                  as​+0x1+...+0xN′−1)是向量[                                             a                            s                                  ,                                   a                            s                                  ,                         .                         .                         .                         ,                                   a                            s                                       a_s,a_s,...,a_s                  as​,as​,...,as​]的加密Enc([                                             a                            s                                  ,                                   a                            s                                  ,                         .                         .                         .                         ,                                   a                            s                                       a_s,a_s,...,a_s                  as​,as​,...,as​​])。


4.1 Application to matrix multiplication

压缩分解技术可以天然地应用于MatrixMul,此外,基于下面的观察结果,本文进一步优化了矩阵乘法,观察到在transformer推理过程中,对于差别输入的矩阵                                   A                         ∈                                   R                                       m                               ×                               n                                                 A\in \R^{m\times n}                  A∈Rm×n须要乘以相同的矩阵                                   W                         ∈                                   R                                       n                               ×                               k                                                 W\in \R^{n\times k}                  W∈Rn×k​。
设                                   A                         =                         [                                   a                            0                                  ,                         .                         .                         .                         ,                                   a                                       n                               −                               1                                            ]                              A = [a_0,...,a_{n-1}]                  A=[a0​,...,an−1​],此中                                             a                            i                                  ∈                                   R                            m                                       a_i\in \R^m                  ai​∈Rm表示矩阵                                   A                              A                  A的第                                   i                              i                  i行。假设S和C须要生成t个响应词,即有t个输入矩阵:
                                                    A                               0                                      =                            [                                       a                                           0                                  ,                                  0                                                 ,                                       a                                           0                                  ,                                  1                                                 ,                            .                            .                            .                            ,                                       a                                           0                                  ,                                  n                                  −                                  1                                                 ]                                                A                               1                                      =                            [                                       a                                           1                                  ,                                  0                                                 ,                                       a                                           1                                  ,                                  1                                                 ,                            .                            .                            .                            ,                                       a                                           1                                  ,                                  n                                  −                                  1                                                 ]                                     .                            .                            .                                                A                               0                                      =                            [                                       a                                           t                                  −                                  1                                  ,                                  0                                                 ,                                       a                                           t                                  −                                  1                                  ,                                  1                                                 ,                            .                            .                            .                            ,                                       a                                           t                                  −                                  1                                  ,                                  n                                  −                                  1                                                 ]                                  A_0=[a_{0,0},a_{0,1},...,a_{0,n-1}] \\ A_1=[a_{1,0},a_{1,1},...,a_{1,n-1}] \\ ... \\ A_0=[a_{t-1,0},a_{t-1,1},...,a_{t-1,n-1}]                     A0​=[a0,0​,a0,1​,...,a0,n−1​]A1​=[a1,0​,a1,1​,...,a1,n−1​]...A0​=[at−1,0​,at−1,1​,...,at−1,n−1​]
令                                             a                            i                            ′                                  =                                   [                                                                                             a                                                           0                                              ,                                              i                                                                                                                                                      a                                                           1                                              ,                                              i                                                                                                                                                      .                                           .                                           .                                                                                                                                       a                                                           t                                              −                                              1                                              ,                                              i                                                                                                       ]                                       a'_i=\left[\begin{matrix} a_{0,i} \\ a_{1,i} \\ ... \\ a_{t-1,i} \end{matrix} \right]                  ai′​=              ​a0,i​a1,i​...at−1,i​​              ​和                                             q                            j                            ′                                  :                         =                                   ∑                                       i                               =                               0                                                 n                               −                               1                                                      a                            i                            ′                                            w                                       i                               ,                               j                                                                        ∀                         j                         ∈                         [                         k                         ]                              q'_j:=\sum^{n-1}_{i=0}a'_iw_{i,j}\ \ \ \forall j\in [k]                  qj′​:=∑i=0n−1​ai′​wi,j​   ∀j∈[k],则有
                                                    Q                               ′                                      =                                       q                               0                               ′                                      ∣                            ∣                                       q                               1                               ′                                      ∣                            ∣                            .                            .                            .                            ∣                            ∣                                       q                                           k                                  −                                  1                                          ′                                      =                                       [                                                                                                                      A                                                 0                                                              W                                                                                                                                                                   A                                                 1                                                              W                                                                                                                                                  .                                              .                                              .                                                                                                                                                                   A                                                                   t                                                    −                                                    1                                                                               W                                                                                                ]                                            Q'=q'_0||q'_1||...||q'_{k-1}=\left[\begin{matrix} A_0W \\ A_1W \\ ... \\ A_{t-1}W \end{matrix} \right]                     Q′=q0′​∣∣q1′​∣∣...∣∣qk−1′​=               ​A0​WA1​W...At−1​W​               ​



预盘算阶段:
这里,我们引入一个预盘算阶段,此中S使用上述提到的密文压缩技术,将压缩后的密文                                   (                         E                         n                                   c                            S                                  (                         [                                                                       w                                                   i                                        ,                                        j                                                           ,                                               w                                                   i                                        ,                                        j                                                           ,                                  .                                  .                                  .                                  ,                                               w                                                   i                                        .                                        j                                                                   ⏟                                                 t                               ×                               m                                            ]                         )                                                    ∀                         i                         ∈                         [                         n                         ]                         ,                         j                         ∈                         [                         k                         ]                         )                              (Enc_S([\underbrace{w_{i,j},w_{i,j},...,w_{i.j}}_{t\times m}])\ \ \forall i \in [n],j\in [k])                  (EncS​([t×m                                                      wi,j​,wi,j​,...,wi.j​​​])  ∀i∈[n],j∈[k])发送给C。注意,该传输仅只发生一次,除非模型发生改变。接下来,C对压缩的密文执行分解技术,以得到                                   E                         n                                   c                            S                                  (                         [                                                                       w                                                   i                                        ,                                        j                                                           ,                                               w                                                   i                                        ,                                        j                                                           ,                                  .                                  .                                  .                                  ,                                               w                                                   i                                        .                                        j                                                                   ⏟                                                 t                               ×                               m                                            ]                         )                                                    ∀                         i                         ∈                         [                         n                         ]                         ,                         j                         ∈                         [                         k                         ]                              Enc_S([\underbrace{w_{i,j},w_{i,j},...,w_{i.j}}_{t\times m}])\ \ \forall i \in [n],j\in [k]                  EncS​([t×m                                                      wi,j​,wi,j​,...,wi.j​​​])  ∀i∈[n],j∈[k]。在预盘算阶段C并没有关于输入的信息,采样                                   U                         ∈                                   R                                       (                               t                               m                               )                               ×                               n                                                 U\in \R^{(tm)\times n}                  U∈R(tm)×n,然后盘算:
                                         E                            n                                       c                               S                                      (                                       v                               j                                      )                            ←                                       ∑                                           i                                  =                                  0                                                      n                                  −                                  1                                                 (                                       u                               i                                      ×                            E                            n                                       c                               S                                      (                            [                                       w                                           i                                  ,                                  j                                                 ,                            .                            .                            .                            ,                                       w                                           i                                  ,                                  j                                                 ]                            )                            )                                                           ∀                            j                            ∈                            [                            k                            ]                                  Enc_S(v_j)\leftarrow\sum^{n-1}_{i=0}(u_i\times Enc_S([w_{i,j},...,w_{i,j}]))\ \ \ \forall j\in[k]                     EncS​(vj​)←i=0∑n−1​(ui​×EncS​([wi,j​,...,wi,j​]))   ∀j∈[k]
此中                                             u                            i                                       u_i                  ui​是矩阵                                   U                              U                  U的第i列。接下来,C使用自己的密钥来加密                                   E                         n                                   c                            S                                  (                                   v                            j                                  )                              Enc_S(v_j)                  EncS​(vj​)以得到                                   E                         n                                   c                            C                                  (                         E                         n                                   c                            S                                  (                                   v                            j                                  )                         )                              Enc_C(Enc_S(v_j))                  EncC​(EncS​(vj​)),并将其发送给S。注意                                   E                         n                                   c                            S                                  (                         E                         n                                   c                            C                                  (                                   v                            j                                  )                         )                         =                         E                         n                                   c                            C                                  (                         E                         n                                   c                            S                                  (                                   v                            j                                  )                         )                              Enc_S(Enc_C(v_j))=Enc_C(Enc_S(v_j))                  EncS​(EncC​(vj​))=EncC​(EncS​(vj​)),故S可以对其进行解密,从而得到                                   E                         n                                   c                            C                                  (                                   v                            j                                  )                              Enc_C(v_j)                  EncC​(vj​)。注意,这里的                                             v                            j                                       v_j                  vj​是矩阵                                   U                         ⋅                         W                              U·W                  U⋅W的第                                   j                              j                  j列。

切换差别用户加密密钥的过程盘算如下:
给定                                   c                                   t                            S                                  =                         (                         −                         a                                   s                            S                                  +                         m                         +                         e                         )                              ct_S=(-as_S+m+e)                  ctS​=(−asS​+m+e),
使用                                             s                            C                                       s_C                  sC​加密有                                   c                                   t                                       C                               ,                               S                                            =                         (                         −                         a                                   s                            S                                  −                         a                                   s                            C                                  +                         m                         +                         e                         +                                   e                            ′                                  ,                         a                         )                              ct_{C,S}=(-as_S-as_C+m+e+e',a)                  ctC,S​=(−asS​−asC​+m+e+e′,a),
使用                                             s                            S                                       s_S                  sS​解密有:                                    c                                   t                            C                                  =                         (                         −                         a                                   s                            S                                  −                         a                                   s                            C                                  +                         m                         +                         e                         +                                   e                            ′                                  ,                         a                         )                         +                         (                         a                                   s                            S                                  ,                         0                         )                         =                         (                         −                         a                                   s                            C                                  +                         m                         +                         e                         +                                   e                            ′                                  )                              ct_C=(-as_S-as_C+m+e+e',a)+(as_S,0)=(-as_C+m+e+e')                  ctC​=(−asS​−asC​+m+e+e′,a)+(asS​,0)=(−asC​+m+e+e′).

在线处理惩罚阶段:
此时,C知道输入的信息                                             A                            ′                                  =                                   a                            0                            ′                                  ∣                         ∣                                   a                            1                            ′                                  ∣                         ∣                         .                         .                         .                         ∣                         ∣                                   a                                       n                               −                               1                                      ′                                       A'=a'_0||a'_1||...||a'_{n-1}                  A′=a0′​∣∣a1′​∣∣...∣∣an−1′​,然后C将明文                                   (                                   A                            ′                                  −                         U                         )                              (A'-U)                  (A′−U)发送给S,注意,由于S不知道U的值,故S也不清楚                                             A                            ′                                       A'                  A′的值,然后S可以盘算:
                                         (                                       A                               ′                                      −                            U                            )                            ⋅                            W                            +                            (                            E                            n                                       c                               C                                      (                                       v                               0                                      )                            ∣                            ∣                            E                            n                                       c                               C                                      (                                       v                               1                                      )                            ∣                            ∣                            .                            .                            .                            ∣                            ∣                            E                            n                                       c                               C                                      (                                       v                                           k                                  −                                  1                                                 )                            )                                     =                            (                                       A                               ′                                      W                            −                            V                            )                            +                            (                            E                            n                                       c                               C                                      (                                       v                               0                                      )                            ∣                            ∣                            E                            n                                       c                               C                                      (                                       v                               1                                      )                            ∣                            ∣                            .                            .                            .                            ∣                            ∣                            E                            n                                       c                               C                                      (                                       v                                           k                                  −                                  1                                                 )                            )                                     =                            (                            E                            n                                       c                               C                                      (                                       q                               0                               ′                                      )                            ∣                            ∣                            E                            n                                       c                               C                                      (                                       q                               1                               ′                                      )                            ∣                            ∣                            .                            .                            .                            ∣                            ∣                            E                            n                                       c                               C                                      (                                       q                                           k                                  −                                  1                                          ′                                      )                            )                                  (A'-U)·W + (Enc_C(v_0)||Enc_C(v_1)||...||Enc_C(v_{k-1}))\\ =(A'W-V)+(Enc_C(v_0)||Enc_C(v_1)||...||Enc_C(v_{k-1})) \\ =(Enc_C(q'_0)||Enc_C(q'_1)||...||Enc_C(q'_{k-1}))                     (A′−U)⋅W+(EncC​(v0​)∣∣EncC​(v1​)∣∣...∣∣EncC​(vk−1​))=(A′W−V)+(EncC​(v0​)∣∣EncC​(v1​)∣∣...∣∣EncC​(vk−1​))=(EncC​(q0′​)∣∣EncC​(q1′​)∣∣...∣∣EncC​(qk−1′​))
此中                                             q                            j                            ′                                       q'_j                  qj′​是矩阵                                             Q                            ′                                       Q'                  Q′的第                                   j                              j                  j列。算法6形貌了优化后的矩阵乘法细节:

可以注意到,只须要预盘算的过程交互一次,此后C可以直接向S发送明文信息                                             A                            ′                                  −                         U                              A'-U                  A′−U,而不会泄露                                             A                            ′                                       A'                  A′的信息。
5 SIMD槽折叠算法

回想矩阵Q,K,V的行向量是使用SIMD方式加密的。而上述介绍的一系列操作,如内积,Softmax,LayerNorm和Argmax等, 均涉及到使用所有槽元素盘算函数                                   f                         (                         ∗                         )                              f(*)                  f(∗),并将得到的结果放置到所有槽上。例如给定                                   E                         n                         c                         (                         [                                   a                            0                                  ,                         .                         .                         .                         ,                                   a                                       N                               −                               1                                            ]                         )                              Enc([a_0,...,a_{N-1}])                  Enc([a0​,...,aN−1​]),然后想要得到                                   E                         n                         c                         (                         [                         s                         ,                         .                         .                         .                         ,                         s                         ]                         )                              Enc([s,...,s])                  Enc([s,...,s]),此中                                   s                         =                                   ∑                                       i                               =                               0                                                 N                               −                               1                                                      a                            i                                       s=\sum^{N-1}_{i=0}a_i                  s=∑i=0N−1​ai​,此时                                   f                         (                         ∗                         )                              f(*)                  f(∗)即求和函数。
本节提供了一种通用的解决方案,只要函数                                   f                         (                         ∗                         )                              f(*)                  f(∗)满足:
                                         f                            (                            f                            (                                       a                               0                                      ,                                       a                               1                                      )                            ,                                       a                               2                                      )                            =                            f                            (                                       a                               0                                      ,                            f                            (                                       a                               1                                      ,                                       a                               2                                      )                            )                                  f(f(a_0,a_1),a_2)=f(a_0,f(a_1,a_2))                     f(f(a0​,a1​),a2​)=f(a0​,f(a1​,a2​))
算法7形貌了槽折叠算法的细节:

这里是一个简单的例子,可以看到算法7的实现流程:

5.1 QuickSum

给定                                   [                                   a                            0                                  ,                                   a                            1                                  ,                         .                         .                         .                         ,                                   a                                       n                               −                               1                                            ,                         0                         ,                         .                         .                         .                         ,                         0                         ]                              [a_0,a_1,...,a_{n-1},0,...,0]                  [a0​,a1​,...,an−1​,0,...,0],为了得到                                   [                                   ∑                                       i                               =                               0                                                 N                               −                               1                                                      a                            i                                  ,                         .                         .                         .                         ,                                   ∑                                       i                               =                               0                                                 N                               −                               1                                                      a                            i                                  ,                         0                         ,                         .                         .                         .                         ,                         0                         ]                              [\sum^{N-1}_{i=0}a_i,...,\sum^{N-1}_{i=0}a_i,0,...,0]                  [∑i=0N−1​ai​,...,∑i=0N−1​ai​,0,...,0],可以将算法7的第5行替换为                                             s                            ~                                  ←                                   s                            ~                                  +                                   a                            ~                                       \tilde{s}\leftarrow\tilde{s}+\tilde{a}                  s~←s~+a~。
5.2 QuickMax

给定                                   [                                   a                            0                                  ,                                   a                            1                                  ,                         .                         .                         .                         ,                                   a                                       n                               −                               1                                            ,                         0                         ,                         .                         .                         .                         ,                         0                         ]                              [a_0,a_1,...,a_{n-1},0,...,0]                  [a0​,a1​,...,an−1​,0,...,0],为了得到                                             a                                       m                               a                               x                                            ,                         .                         .                         .                         ,                                   a                                       m                               a                               x                                            ,                         0                         ,                         .                         .                         .                         ,                         0                              a_{max},...,a_{max},0,...,0                  amax​,...,amax​,0,...,0,此中                                             a                                       m                               a                               x                                            =                         m                         a                         x                         (                                   a                            0                                  ,                                   a                            1                                  ,                         .                         .                         .                         ,                                   a                                       n                               −                               1                                            )                              a_{max}=max(a_0,a_1,...,a_{n-1})                  amax​=max(a0​,a1​,...,an−1​),很明显                                   m                         a                         x                         (                         a                         ,                         b                         )                              max(a,b)                  max(a,b)可以表示为:
                                         m                            a                            x                            (                            a                            ,                            b                            )                            =                                                   a                                  +                                  b                                  +                                  (                                  a                                  −                                  b                                  )                                  ⋅                                  S                                  g                                  n                                  (                                  a                                  −                                  b                                  )                                          2                                            max(a,b)={a+b+(a-b)·Sgn(a-b)\over 2}                     max(a,b)=2a+b+(a−b)⋅Sgn(a−b)​
因此可以将算法7的第5行替换为:
                                                    s                               ~                                      ←                            0.5                            ⊗                            (                                       a                               ~                                      ⊕                                       s                               ~                                      ⊕                            (                                       a                               ~                                      ⊖                                       s                               ~                                      )                            ⊗                            S                            g                            n                            (                                       a                               ~                                      ⊖                                       s                               ~                                      )                            )                                  \tilde{s}\leftarrow 0.5\otimes(\tilde{a}\oplus\tilde{s}\oplus(\tilde{a}\ominus\tilde{s})\otimes Sgn(\tilde{a}\ominus\tilde{s}))                     s~←0.5⊗(a~⊕s~⊕(a~⊖s~)⊗Sgn(a~⊖s~))
6. Conclusion

本文提出了NEXUS体系,可以说是第一个不须要客户端和服务器进行交互的安全transformer推理协议。本文提出了实用于RNS-CKKS的一系列新协议,以使得服务器可以高效且精确的在加密数据上盘算transformer的每一层。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

花瓣小跑

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表