搜广推校招面经五十四

打印 上一主题 下一主题

主题 972|帖子 972|积分 2916

美团保举算法

一、手撕Transformer的位置编码

1.1. 位置编码的作用

Transformer 模型没有显式的序列信息(如 RNN 的循环结构),因此必要通过位置编码(Positional Encoding)为输入序列中的每个位置添加位置信息。位置编码的作用是:


  • 提供序列位置信息:资助模型理解输入序列中元素的次序。
  • 保持唯一性和连续性:确保每个位置的位置编码是唯一的,且相邻位置的位置编码是连续的。
1.2. 位置编码公式

Transformer 利用正弦和余弦函数生成位置编码,公式如下:
                                         P                                       E                                           (                                  p                                  o                                  s                                  ,                                  2                                  i                                  )                                                 =                            sin                            ⁡                                       (                                                        p                                     o                                     s                                                           1000                                                   0                                                                       2                                              i                                                                          d                                              model                                                                                                )                                                                                     P                                       E                                           (                                  p                                  o                                  s                                  ,                                  2                                  i                                  +                                  1                                  )                                                 =                            cos                            ⁡                                       (                                                        p                                     o                                     s                                                           1000                                                   0                                                                       2                                              i                                                                          d                                              model                                                                                                )                                            PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{\frac{2i}{d_{\text{model}}}}}\right) \\ \ \\ PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{\frac{2i}{d_{\text{model}}}}}\right)                     PE(pos,2i)​=sin(10000dmodel​2i​pos​) PE(pos,2i+1)​=cos(10000dmodel​2i​pos​)
其中:


  •                                         p                            o                            s                                  pos                     pos:位置索引(从 0 开始)。
  •                                         i                                  i                     i:维度索引(从 0 到 ( \frac{d_{\text{model}}}{2} - 1$)。
  •                                                    d                               model                                            d_{\text{model}}                     dmodel​:模型的嵌入维度。
1.3. PyTorch 实现

以下是利用 PyTorch 实现位置编码的代码:
  1. import torch
  2. import torch.nn as nn
  3. class PositionalEncoding(nn.Module):
  4.     def __init__(self, d_model, max_len=5000):
  5.         """
  6.         初始化位置编码
  7.         :param d_model: 嵌入维度
  8.         :param max_len: 最大序列长度
  9.         """
  10.         super(PositionalEncoding, self).__init__()
  11.         
  12.         # 初始化位置编码矩阵
  13.         pe = torch.zeros(max_len, d_model)
  14.         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
  15.         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))  # (d_model / 2)
  16.         
  17.         # 计算位置编码
  18.         pe[:, 0::2] = torch.sin(position * div_term)  # 偶数位置使用正弦
  19.         pe[:, 1::2] = torch.cos(position * div_term)  # 奇数位置使用余弦
  20.         
  21.         # 注册为缓冲区(不参与训练)
  22.         self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)
  23.    
  24.     def forward(self, x):
  25.         """
  26.         前向传播
  27.         :param x: 输入张量,形状为 (batch_size, seq_len, d_model)
  28.         :return: 添加位置编码后的张量,形状为 (batch_size, seq_len, d_model)
  29.         """
  30.         x = x + self.pe[:, :x.size(1)]  # 添加位置编码
  31.         return x
  32. # 示例
  33. d_model = 512  # 嵌入维度
  34. max_len = 50   # 最大序列长度
  35. batch_size = 10  # 批量大小
  36. seq_len = 20     # 序列长度
  37. # 创建位置编码层
  38. pe = PositionalEncoding(d_model, max_len)
  39. # 随机生成输入张量
  40. x = torch.randn(batch_size, seq_len, d_model)
  41. # 添加位置编码
  42. x_with_pe = pe(x)
  43. print(x_with_pe.shape)  # 输出: torch.Size([10, 20, 512])
复制代码
二、为什么 Multi-Head Attention 没有改变 QKV 计算的参数目但对结果有提升?

2.1. Multi-Head Attention 的根本原理

Multi-Head Attention 是 Transformer 模型的核心组件之一,其核心思想是通过多个注意力头(Attention Head)并行计算注意力,然后将结果拼接起来。具体步调如下:

  • 线性变换:将输入                                         Q                            、                            K                            、                            V                                  Q、K、V                     Q、K、V 分别通过线性变换生成多个头的                                                    Q                               i                                      、                                       K                               i                                      、                                       V                               i                                            Q_i、K_i、V_i                     Qi​、Ki​、Vi​ 。
  • 并行计算:每个头独立计算注意力分数。
  • 拼接和线性变换:将多个头的输出拼接起来,并通过线性变换得到最终输出。
2.2. 结果提升的原因

尽管参数目没有增长,但多头注意力对结果的提升重要来自以下几个方面:
(1)并行计算



  • 多个头可以并行计算注意力,捕捉输入序列中不同位置的不同特性。每个头可以关注不同的子空间,从而增强模型的表达能力。
(2)多视角学习



  • 每个头可以学习到不同的注意力模式(如局部依赖、全局依赖等)。通过拼接多个头的输出,模型可以综合多个视角的信息,提升泛化能力。
(3)特性多样性



  • 多头注意力可以捕捉输入序列中不同层次的特性(如语法、语义等)。这种多样性有助于模型更好地理解复杂的序列数据。
(4)计算效率



  • 固然参数目没有增长,但多头注意力通过并行计算提高了计算效率。每个头的维度减小,减少了计算复杂度。
三、Word2Vec 的原理及损失函数界说

Word2Vec 是一种用于学习词向量的模型,其核心思想是通过上下文预测目标词(Skip-gram)或通过目标词预测上下文(CBOW)。Word2Vec 的目标是将每个词映射到一个低维稠密向量空间中,使得语义相似的词在向量空间中距离较近。
(1)Skip-gram 模型



  • 目标:给定一个中心词,预测其上下文词。
  • 输入:中心词。
  • 输出:上下文词的概率分布。
(2)CBOW 模型



  • 目标:给定上下文词,预测中心词。
  • 输入:上下文词。
  • 输出:中心词的概率分布。
3.2. Word2Vec 的损失函数

Word2Vec 的损失函数通常利用 负对数似然损失(Negative Log-Likelihood Loss),具体界说如下:
(1)Skip-gram 的损失函数

对于 Skip-gram 模型,损失函数界说为:
                                         L                            =                            −                                       1                               T                                                 ∑                                           t                                  =                                  1                                          T                                                 ∑                                           −                                  c                                  ≤                                  j                                  ≤                                  c                                  ,                                  j                                  ≠                                  0                                                 log                            ⁡                            p                            (                                       w                                           t                                  +                                  j                                                 ∣                                       w                               t                                      )                                  L = -\frac{1}{T} \sum_{t=1}^{T} \sum_{-c \leq j \leq c, j \neq 0} \log p(w_{t+j} | w_t)                     L=−T1​t=1∑T​−c≤j≤c,j=0∑​logp(wt+j​∣wt​)
其中:


  •                                         T                                  T                     T:语料库中的总词数。
  •                                         c                                  c                     c:上下文窗口大小。
  •                                                    w                               t                                            w_t                     wt​:中心词。
  •                                                    w                                           t                                  +                                  j                                                       w_{t+j}                     wt+j​:上下文词。
  •                                         p                            (                                       w                                           t                                  +                                  j                                                 ∣                                       w                               t                                      )                                  p(w_{t+j} | w_t)                     p(wt+j​∣wt​):给定中心词                                                    w                               t                                            w_t                     wt​ 时,上下文词                                                    w                                           t                                  +                                  j                                                       w_{t+j}                     wt+j​ 的条件概率。
(2)CBOW 的损失函数

对于 CBOW 模型,损失函数界说为:
                                         L                            =                            −                                       1                               T                                                 ∑                                           t                                  =                                  1                                          T                                      log                            ⁡                            p                            (                                       w                               t                                      ∣                                       w                                           t                                  −                                  c                                                 ,                            …                            ,                                       w                                           t                                  −                                  1                                                 ,                                       w                                           t                                  +                                  1                                                 ,                            …                            ,                                       w                                           t                                  +                                  c                                                 )                                  L = -\frac{1}{T} \sum_{t=1}^{T} \log p(w_t | w_{t-c}, \dots, w_{t-1}, w_{t+1}, \dots, w_{t+c})                     L=−T1​t=1∑T​logp(wt​∣wt−c​,…,wt−1​,wt+1​,…,wt+c​)
其中:


  •                                         T                                  T                     T:语料库中的总词数。
  •                                         c                                  c                     c:上下文窗口大小。
  •                                                    w                               t                                            w_t                     wt​:中心词。
  •                                                    w                                           t                                  −                                  c                                                 ,                            …                            ,                                       w                                           t                                  +                                  c                                                       w_{t-c}, \dots, w_{t+c}                     wt−c​,…,wt+c​:上下文词。
  •                                         p                            (                                       w                               t                                      ∣                                       w                                           t                                  −                                  c                                                 ,                            …                            ,                                       w                                           t                                  +                                  c                                                 )                                  p(w_t | w_{t-c}, \dots, w_{t+c})                     p(wt​∣wt−c​,…,wt+c​):给定上下文词时,中心词                                                    w                               t                                            w_t                     wt​ 的条件概率。
(3)条件概率的计算

条件概率                                    p                         (                                   w                            O                                  ∣                                   w                            I                                  )                              p(w_O | w_I)                  p(wO​∣wI​) 通过 Softmax 函数计算:
                                         p                            (                                       w                               O                                      ∣                                       w                               I                                      )                            =                                                   exp                                  ⁡                                  (                                               v                                                   w                                        O                                                  T                                                           v                                                   w                                        I                                                           )                                                                   ∑                                                   w                                        =                                        1                                                  V                                              exp                                  ⁡                                  (                                               v                                     w                                     T                                                           v                                                   w                                        I                                                           )                                                       p(w_O | w_I) = \frac{\exp(v_{w_O}^T v_{w_I})}{\sum_{w=1}^{V} \exp(v_w^T v_{w_I})}                     p(wO​∣wI​)=∑w=1V​exp(vwT​vwI​​)exp(vwO​T​vwI​​)​
其中:


  •                                                    v                                           w                                  I                                                       v_{w_I}                     vwI​​:输入词                                                    w                               I                                            w_I                     wI​ 的向量表示。
  •                                                    v                                           w                                  O                                                       v_{w_O}                     vwO​​:输出词                                                    w                               O                                            w_O                     wO​ 的向量表示。
  •                                         V                                  V                     V:词汇表大小。
3. 负采样(Negative Sampling)

由于 Softmax 的计算复杂度较高(与词汇表大小                                    V                              V                  V 成正比),Word2Vec 通常利用负采样(Negative Sampling)来近似损失函数。负采样的损失函数界说为:
                                         L                            =                            −                            log                            ⁡                            σ                            (                                       v                                           w                                  O                                          T                                                 v                                           w                                  I                                                 )                            −                                       ∑                                           i                                  =                                  1                                          k                                      log                            ⁡                            σ                            (                            −                                       v                                           w                                  i                                          T                                                 v                                           w                                  I                                                 )                                  L = -\log \sigma(v_{w_O}^T v_{w_I}) - \sum_{i=1}^{k} \log \sigma(-v_{w_i}^T v_{w_I})                     L=−logσ(vwO​T​vwI​​)−i=1∑k​logσ(−vwi​T​vwI​​)
其中:


  •                                         σ                                  \sigma                     σ:Sigmoid 函数。
  •                                         k                                  k                     k:负样本的数目。
  •                                                    w                               i                                            w_i                     wi​:负样本词。
四、为什么可以通过负采样近似 Softmax?

4.1. Softmax 的计算复杂度问题

Softmax 函数的计算复杂度为                                    O                         (                         V                         )                              O(V)                  O(V),其中                                    V                              V                  V 是词汇表的大小。对于大规模词汇表(如数百万词),Softmax 的计算本钱非常高,重要体现在:


  • 计算指数:必要对每个词计算指数。
  • 归一化:必要对全部词的指数求和,然后归一化。
4.2. 负采样的根本思想

负采样(Negative Sampling)是一种近似 Softmax 的方法,通过采样少量负样本来替换全词汇表的计算。其核心思想是:


  • 正样本:目标词(实际出现在上下文中的词)。
  • 负样本:随机采样的非目标词(未出现在上下文中的词)。
  • 目标:最大化正样本的概率,最小化负样本的概率。
4.3. 负采样的数学原理

(1)Softmax 的原始情势

Softmax 的条件概率界说为:
                                         p                            (                                       w                               O                                      ∣                                       w                               I                                      )                            =                                                   exp                                  ⁡                                  (                                               v                                                   w                                        O                                                  T                                                           v                                                   w                                        I                                                           )                                                                   ∑                                                   w                                        =                                        1                                                  V                                              exp                                  ⁡                                  (                                               v                                     w                                     T                                                           v                                                   w                                        I                                                           )                                                       p(w_O | w_I) = \frac{\exp(v_{w_O}^T v_{w_I})}{\sum_{w=1}^{V} \exp(v_w^T v_{w_I})}                     p(wO​∣wI​)=∑w=1V​exp(vwT​vwI​​)exp(vwO​T​vwI​​)​
其中:


  •                                                    v                                           w                                  I                                                       v_{w_I}                     vwI​​:输入词                                                    w                               I                                            w_I                     wI​的向量表示。
  •                                                    v                                           w                                  O                                                       v_{w_O}                     vwO​​:输出词                                                    w                               O                                            w_O                     wO​ 的向量表示。
  •                                         V                                  V                     V:词汇表大小。
(2)负采样的近似情势

负采样通过采样少量负样本                                              w                            i                                       w_i                  wi​ 来近似 Softmax 的分母。具体步调如下:

  • 正样本:计算正样本的概率:
                                                  σ                               (                                           v                                               w                                     O                                              T                                                      v                                               w                                     I                                                      )                                      \sigma(v_{w_O}^T v_{w_I})                        σ(vwO​T​vwI​​)
    其中                                         σ                                  \sigma                     σ是 Sigmoid 函数。
  • 负样本:计算负样本的概率:
                                                  σ                               (                               −                                           v                                               w                                     i                                              T                                                      v                                               w                                     I                                                      )                                      \sigma(-v_{w_i}^T v_{w_I})                        σ(−vwi​T​vwI​​)
  • 损失函数:将正样本和负样本的概率结合起来,界说损失函数:
                                                  L                               =                               −                               log                               ⁡                               σ                               (                                           v                                               w                                     O                                              T                                                      v                                               w                                     I                                                      )                               −                                           ∑                                               i                                     =                                     1                                              k                                          log                               ⁡                               σ                               (                               −                                           v                                               w                                     i                                              T                                                      v                                               w                                     I                                                      )                                      L = -\log \sigma(v_{w_O}^T v_{w_I}) - \sum_{i=1}^{k} \log \sigma(-v_{w_i}^T v_{w_I})                        L=−logσ(vwO​T​vwI​​)−i=1∑k​logσ(−vwi​T​vwI​​)
    其中                                         k                                  k                     k 是负样本的数目。
(3)为什么可以近似?



  • 分母的近似:Softmax 的分母是对全部词的指数求和,计算复杂度高。负采样通过采样少量负样本,近似计算分母(但是牺牲精度)。
五、召回的评价指标



免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

花瓣小跑

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表