Llama 2中的Margin Loss:为何更高的Margin导致更大的Loss和梯度? ...

打印 上一主题 下一主题

主题 978|帖子 978|积分 2934

Llama 2中的Margin Loss:为何更高的Margin导致更大的Loss和梯度?

在《Llama 2: Open Foundation and Fine-Tuned Chat Models》论文中,作者在强化学习与人类反馈(RLHF)的Reward Model练习中引入了Margin Loss的概念,相较于传统的InstructGPT方法有所创新。下面有一段关键描述:
   “For instance, returning a higher margin via ‘m( r)’ will make the difference between the reward of the preferred and rejected responses smaller, resulting in a larger loss, which in turn results in larger gradients, and consequently model changes, during the policy gradient update.”
source: https://magazine.sebastianraschka.com/p/llm-training-rlhf-and-its-alternatives
  这段话涉及Margin Loss的逻辑:为什么更高的margin会导致更大的loss?为什么更大的loss会导致更大的梯度?本文将以中文博客的情势,详细分析这个过程的数学原理和直观意义,帮助你明白其中的因果关系。

1. Margin Loss的根本概念

在RLHF的Reward Model练习中,目标是让模子学会根据人类偏好对响应进行评分。对于一对响应 (                                              y                            c                                       y_c                  yc​ )(优选响应,chosen)和 (                                              y                            r                                       y_r                  yr​ )(拒绝响应,rejected),Reward Model (                                              r                            θ                                  (                         x                         ,                         y                         )                              r_\theta(x, y)                  rθ​(x,y) ) 输出标量夸奖值,要求 (                                              r                            θ                                  (                         x                         ,                                   y                            c                                  )                         >                                   r                            θ                                  (                         x                         ,                                   y                            r                                  )                              r_\theta(x, y_c) > r_\theta(x, y_r)                  rθ​(x,yc​)>rθ​(x,yr​) )。
传统损失函数

传统的InstructGPT使用基于交织熵的排名损失:这个loss是如何推导的,请参考笔者的另一篇博客:RLHF中的Reward Model是如何练习的?原理与代码实现
                                         L                            (                            θ                            )                            =                            −                            log                            ⁡                                       (                               σ                                           (                                               r                                     θ                                              (                                  x                                  ,                                               y                                     c                                              )                                  −                                               r                                     θ                                              (                                  x                                  ,                                               y                                     r                                              )                                  )                                          )                                            L(\theta) = -\log\left(\sigma\left(r_\theta(x, y_c) - r_\theta(x, y_r)\right)\right)                     L(θ)=−log(σ(rθ​(x,yc​)−rθ​(x,yr​)))


  • (                                         σ                            (                            z                            )                            =                                       1                                           1                                  +                                  exp                                  ⁡                                  (                                  −                                  z                                  )                                                       \sigma(z) = \frac{1}{1 + \exp(-z)}                     σ(z)=1+exp(−z)1​ ) 是sigmoid函数,将差值映射为0到1的概率。
  • (                                                    r                               θ                                      (                            x                            ,                                       y                               c                                      )                            −                                       r                               θ                                      (                            x                            ,                                       y                               r                                      )                                  r_\theta(x, y_c) - r_\theta(x, y_r)                     rθ​(x,yc​)−rθ​(x,yr​) ) 是优选和拒绝响应的夸奖差值。
  • 损失的目标是使 (                                                    r                               θ                                      (                            x                            ,                                       y                               c                                      )                            −                                       r                               θ                                      (                            x                            ,                                       y                               r                                      )                                  r_\theta(x, y_c) - r_\theta(x, y_r)                     rθ​(x,yc​)−rθ​(x,yr​) ) 尽大概大,从而让 (                                         σ                                  \sigma                     σ ) 接近1,损失接近0。
Llama 2的Margin Loss

Llama 2在此基础上增加了margin参数 (                                    m                         (                         r                         )                              m(r)                  m(r) ):
                                         L                            (                            θ                            )                            =                            −                            log                            ⁡                                       (                               σ                                           (                                               r                                     θ                                              (                                  x                                  ,                                               y                                     c                                              )                                  −                                               r                                     θ                                              (                                  x                                  ,                                               y                                     r                                              )                                  −                                  m                                  (                                  r                                  )                                  )                                          )                                            L(\theta) = -\log\left(\sigma\left(r_\theta(x, y_c) - r_\theta(x, y_r) - m(r)\right)\right)                     L(θ)=−log(σ(rθ​(x,yc​)−rθ​(x,yr​)−m(r)))


  • (                                         m                            (                            r                            )                                  m(r)                     m(r) ) 是人类标注的偏好程度(margin label),比如“显著更好”(significantly better)对应较大的 (                                         m                            (                            r                            )                                  m(r)                     m(r) ),而“略好”(negligibly better)对应较小的 (                                         m                            (                            r                            )                                  m(r)                     m(r) )。
  • (                                         m                            (                            r                            )                                  m(r)                     m(r) ) 是一个正值,表示优选响应比拒绝响应“应该”高出的最小夸奖差距。

2. 为什么更高的Margin导致更大的Loss?

直观明白



  • (                                                    r                               θ                                      (                            x                            ,                                       y                               c                                      )                            −                                       r                               θ                                      (                            x                            ,                                       y                               r                                      )                                  r_\theta(x, y_c) - r_\theta(x, y_r)                     rθ​(x,yc​)−rθ​(x,yr​) ) 是模子当前猜测的夸奖差值。
  • (                                         m                            (                            r                            )                                  m(r)                     m(r) ) 是人类期望的“理想差值”。
  • 损失函数中的 (                                                    r                               θ                                      (                            x                            ,                                       y                               c                                      )                            −                                       r                               θ                                      (                            x                            ,                                       y                               r                                      )                            −                            m                            (                            r                            )                                  r_\theta(x, y_c) - r_\theta(x, y_r) - m(r)                     rθ​(x,yc​)−rθ​(x,yr​)−m(r) ) 表示“实际差值”与“期望差值”的差距。
当 (                                    m                         (                         r                         )                              m(r)                  m(r) ) 变大时:


  • 如果模子的猜测差值 (                                                    r                               θ                                      (                            x                            ,                                       y                               c                                      )                            −                                       r                               θ                                      (                            x                            ,                                       y                               r                                      )                                  r_\theta(x, y_c) - r_\theta(x, y_r)                     rθ​(x,yc​)−rθ​(x,yr​) ) 不变,减去一个更大的 (                                         m                            (                            r                            )                                  m(r)                     m(r) ) 会使 (                                                    r                               θ                                      (                            x                            ,                                       y                               c                                      )                            −                                       r                               θ                                      (                            x                            ,                                       y                               r                                      )                            −                            m                            (                            r                            )                                  r_\theta(x, y_c) - r_\theta(x, y_r) - m(r)                     rθ​(x,yc​)−rθ​(x,yr​)−m(r) ) 变小(甚至大概变成负值)。
  • (                                         σ                                  \sigma                     σ ) 函数的值随之变小(因为 (                                         σ                            (                            z                            )                                  \sigma(z)                     σ(z) ) 是单调递增的,(                                         z                                  z                     z ) 减小则 (                                         σ                            (                            z                            )                                  \sigma(z)                     σ(z) ) 减小)。
  • (                                         −                            log                            ⁡                            (                            σ                            (                            z                            )                            )                                  -\log(\sigma(z))                     −log(σ(z)) ) 会变大,因为 (                                         σ                            (                            z                            )                                  \sigma(z)                     σ(z) ) 越小,对数的值越大,负号使损失增加。
简朴来说,更高的 (                                    m                         (                         r                         )                              m(r)                  m(r) ) 进步了对模子的要求。如果模子的猜测差值没有达到这个更高的尺度,损失就会增大。
数学推导

设:


  •                                         z                            =                                       r                               θ                                      (                            x                            ,                                       y                               c                                      )                            −                                       r                               θ                                      (                            x                            ,                                       y                               r                                      )                            −                            m                            (                            r                            )                                  z = r_\theta(x, y_c) - r_\theta(x, y_r) - m(r)                     z=rθ​(x,yc​)−rθ​(x,yr​)−m(r)。
损失函数为:
                                         L                            =                            −                            log                            ⁡                            (                            σ                            (                            z                            )                            )                            =                            −                            log                            ⁡                                       (                                           1                                               1                                     +                                     exp                                     ⁡                                     (                                     −                                     z                                     )                                                      )                                            L = -\log(\sigma(z)) = -\log\left(\frac{1}{1 + \exp(-z)}\right)                     L=−log(σ(z))=−log(1+exp(−z)1​)


  • 当 (                                         m                            (                            r                            )                                  m(r)                     m(r) ) 增加时,(                                         z                                  z                     z ) 减小。
  • (                                         exp                            ⁡                            (                            −                            z                            )                                  \exp(-z)                     exp(−z) ) 增大(因为 (                                         −                            z                                  -z                     −z ) 变大),使 (                                         1                            +                            exp                            ⁡                            (                            −                            z                            )                                  1 + \exp(-z)                     1+exp(−z) ) 增大。
  • (                                         σ                            (                            z                            )                            =                                       1                                           1                                  +                                  exp                                  ⁡                                  (                                  −                                  z                                  )                                                       \sigma(z) = \frac{1}{1 + \exp(-z)}                     σ(z)=1+exp(−z)1​ ) 减小。
  • (                                         −                            log                            ⁡                            (                            σ                            (                            z                            )                            )                                  -\log(\sigma(z))                     −log(σ(z)) ) 增大,即损失 (                                         L                                  L                     L ) 增大。
举例阐明

假设:


  • (                                                    r                               θ                                      (                            x                            ,                                       y                               c                                      )                            =                            2                                  r_\theta(x, y_c) = 2                     rθ​(x,yc​)=2 ),(                                                    r                               θ                                      (                            x                            ,                                       y                               r                                      )                            =                            1                                  r_\theta(x, y_r) = 1                     rθ​(x,yr​)=1 ),猜测差值 (                                                    r                               θ                                      (                            x                            ,                                       y                               c                                      )                            −                                       r                               θ                                      (                            x                            ,                                       y                               r                                      )                            =                            1                                  r_\theta(x, y_c) - r_\theta(x, y_r) = 1                     rθ​(x,yc​)−rθ​(x,yr​)=1。
  • 环境1:(                                         m                            (                            r                            )                            =                            0                                  m(r) = 0                     m(r)=0 )(无margin):

    • (                                                   z                                  =                                  1                                  −                                  0                                  =                                  1                                          z = 1 - 0 = 1                           z=1−0=1 ),
    • (                                                   σ                                  (                                  1                                  )                                  =                                               1                                                   1                                        +                                        exp                                        ⁡                                        (                                        −                                        1                                        )                                                           ≈                                  0.731                                          \sigma(1) = \frac{1}{1 + \exp(-1)} \approx 0.731                           σ(1)=1+exp(−1)1​≈0.731 ),
    • (                                                   L                                  =                                  −                                  log                                  ⁡                                  (                                  0.731                                  )                                  ≈                                  0.313                                          L = -\log(0.731) \approx 0.313                           L=−log(0.731)≈0.313 )。

  • 环境2:(                                         m                            (                            r                            )                            =                            0.5                                  m(r) = 0.5                     m(r)=0.5 )(中等margin):

    • (                                                   z                                  =                                  1                                  −                                  0.5                                  =                                  0.5                                          z = 1 - 0.5 = 0.5                           z=1−0.5=0.5 ),
    • (                                                   σ                                  (                                  0.5                                  )                                  ≈                                  0.622                                          \sigma(0.5) \approx 0.622                           σ(0.5)≈0.622 ),
    • (                                                   L                                  =                                  −                                  log                                  ⁡                                  (                                  0.622                                  )                                  ≈                                  0.475                                          L = -\log(0.622) \approx 0.475                           L=−log(0.622)≈0.475 )。

  • 环境3:(                                         m                            (                            r                            )                            =                            1                                  m(r) = 1                     m(r)=1 )(高margin):

    • (                                                   z                                  =                                  1                                  −                                  1                                  =                                  0                                          z = 1 - 1 = 0                           z=1−1=0 ),
    • (                                                   σ                                  (                                  0                                  )                                  =                                  0.5                                          \sigma(0) = 0.5                           σ(0)=0.5 ),
    • (                                                   L                                  =                                  −                                  log                                  ⁡                                  (                                  0.5                                  )                                  ≈                                  0.693                                          L = -\log(0.5) \approx 0.693                           L=−log(0.5)≈0.693 )。

可以看到,(                                    m                         (                         r                         )                              m(r)                  m(r) ) 从0增加到1,损失从0.313增加到0.693,验证了更高的margin导致更大的loss。

3. 为什么更大的Loss会导致更大的梯度?

梯度的定义

在神经网络中,梯度是损失函数 (                                    L                              L                  L ) 对模子参数 (                                    θ                              \theta                  θ ) 的偏导数:
                                                    ∇                               θ                                      L                            =                                                   ∂                                  L                                                      ∂                                  θ                                                       \nabla_\theta L = \frac{\partial L}{\partial \theta}                     ∇θ​L=∂θ∂L​
梯度的巨细决定了参数更新的步幅(通过学习率调整)。我们需要分析 (                                    L                              L                  L ) 如何通过 (                                    z                              z                  z ) 影响 (                                    θ                              \theta                  θ )。
计算梯度

损失函数:
                                         L                            =                            −                            log                            ⁡                            (                            σ                            (                            z                            )                            )                            ,                                     z                            =                                       r                               θ                                      (                            x                            ,                                       y                               c                                      )                            −                                       r                               θ                                      (                            x                            ,                                       y                               r                                      )                            −                            m                            (                            r                            )                                  L = -\log(\sigma(z)),\quad z = r_\theta(x, y_c) - r_\theta(x, y_r) - m(r)                     L=−log(σ(z)),z=rθ​(x,yc​)−rθ​(x,yr​)−m(r)


  • 首先计算 (                                                                       ∂                                     L                                                           ∂                                     z                                                             \frac{\partial L}{\partial z}                        ∂z∂L​ ):

    • (                                                   σ                                  (                                  z                                  )                                  =                                               1                                                   1                                        +                                        exp                                        ⁡                                        (                                        −                                        z                                        )                                                                   \sigma(z) = \frac{1}{1 + \exp(-z)}                           σ(z)=1+exp(−z)1​ ),
    • (                                                                              d                                        σ                                        (                                        z                                        )                                                                d                                        z                                                           =                                  σ                                  (                                  z                                  )                                  ⋅                                  (                                  1                                  −                                  σ                                  (                                  z                                  )                                  )                                          \frac{d\sigma(z)}{dz} = \sigma(z) \cdot (1 - \sigma(z))                           dzdσ(z)​=σ(z)⋅(1−σ(z)) )(sigmoid的导数),
    • (                                                   L                                  =                                  −                                  log                                  ⁡                                  (                                  σ                                  (                                  z                                  )                                  )                                          L = -\log(\sigma(z))                           L=−log(σ(z)) ),
    • (                                                                              ∂                                        L                                                                ∂                                        z                                                           =                                  −                                               1                                                   σ                                        (                                        z                                        )                                                           ⋅                                                             d                                        σ                                        (                                        z                                        )                                                                d                                        z                                                           =                                  −                                                             σ                                        (                                        z                                        )                                        ⋅                                        (                                        1                                        −                                        σ                                        (                                        z                                        )                                        )                                                                σ                                        (                                        z                                        )                                                           =                                  −                                  (                                  1                                  −                                  σ                                  (                                  z                                  )                                  )                                          \frac{\partial L}{\partial z} = -\frac{1}{\sigma(z)} \cdot \frac{d\sigma(z)}{dz} = -\frac{\sigma(z) \cdot (1 - \sigma(z))}{\sigma(z)} = -(1 - \sigma(z))                           ∂z∂L​=−σ(z)1​⋅dzdσ(z)​=−σ(z)σ(z)⋅(1−σ(z))​=−(1−σ(z)) )。

  • 然后计算 (                                                                       ∂                                     z                                                           ∂                                     θ                                                             \frac{\partial z}{\partial \theta}                        ∂θ∂z​ ):

    • (                                                   z                                  =                                               r                                     θ                                              (                                  x                                  ,                                               y                                     c                                              )                                  −                                               r                                     θ                                              (                                  x                                  ,                                               y                                     r                                              )                                  −                                  m                                  (                                  r                                  )                                          z = r_\theta(x, y_c) - r_\theta(x, y_r) - m(r)                           z=rθ​(x,yc​)−rθ​(x,yr​)−m(r) ),
    • (                                                                              ∂                                        z                                                                ∂                                        θ                                                           =                                                             ∂                                                       r                                           θ                                                      (                                        x                                        ,                                                       y                                           c                                                      )                                                                ∂                                        θ                                                           −                                                             ∂                                                       r                                           θ                                                      (                                        x                                        ,                                                       y                                           r                                                      )                                                                ∂                                        θ                                                                   \frac{\partial z}{\partial \theta} = \frac{\partial r_\theta(x, y_c)}{\partial \theta} - \frac{\partial r_\theta(x, y_r)}{\partial \theta}                           ∂θ∂z​=∂θ∂rθ​(x,yc​)​−∂θ∂rθ​(x,yr​)​ )((                                                   m                                  (                                  r                                  )                                          m(r)                           m(r) ) 是常数,对 (                                                   θ                                          \theta                           θ ) 无导数)。

  • 综合得梯度:
                                                                ∂                                  L                                                      ∂                                  θ                                                 =                                                   ∂                                  L                                                      ∂                                  z                                                 ⋅                                                   ∂                                  z                                                      ∂                                  θ                                                 =                            −                            (                            1                            −                            σ                            (                            z                            )                            )                            ⋅                                       (                                                        ∂                                                   r                                        θ                                                  (                                     x                                     ,                                                   y                                        c                                                  )                                                           ∂                                     θ                                                      −                                                        ∂                                                   r                                        θ                                                  (                                     x                                     ,                                                   y                                        r                                                  )                                                           ∂                                     θ                                                      )                                            \frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial \theta} = -(1 - \sigma(z)) \cdot \left(\frac{\partial r_\theta(x, y_c)}{\partial \theta} - \frac{\partial r_\theta(x, y_r)}{\partial \theta}\right)                     ∂θ∂L​=∂z∂L​⋅∂θ∂z​=−(1−σ(z))⋅(∂θ∂rθ​(x,yc​)​−∂θ∂rθ​(x,yr​)​)
更高的Margin如何影响梯度



  • 当 (                                         m                            (                            r                            )                                  m(r)                     m(r) ) 增加时,(                                         z                                  z                     z ) 减小,(                                         σ                            (                            z                            )                                  \sigma(z)                     σ(z) ) 减小。
  • (                                         1                            −                            σ                            (                            z                            )                                  1 - \sigma(z)                     1−σ(z) ) 增大(因为 (                                         σ                            (                            z                            )                                  \sigma(z)                     σ(z) ) 接近0时,(                                         1                            −                            σ                            (                            z                            )                                  1 - \sigma(z)                     1−σ(z) ) 接近1)。
  • (                                        −                            (                            1                            −                            σ                            (                            z                            )                            )                                  -(1 - \sigma(z))                     −(1−σ(z)) ) 的绝对值增大(负值变动大),使梯度的绝对值 (                                         ∣                                                   ∂                                  L                                                      ∂                                  θ                                                 ∣                                  |\frac{\partial L}{\partial \theta}|                     ∣∂θ∂L​∣ ) 增大。
举例验证

继续上例:


  • (                                         m                            (                            r                            )                            =                            0                                  m(r) = 0                     m(r)=0 ):(                                         z                            =                            1                                  z = 1                     z=1 ),(                                         σ                            (                            1                            )                            ≈                            0.731                                  \sigma(1) \approx 0.731                     σ(1)≈0.731 ),(                                         1                            −                            σ                            (                            1                            )                            ≈                            0.269                                  1 - \sigma(1) \approx 0.269                     1−σ(1)≈0.269 ),

    • 梯度因子 (                                                   −                                  (                                  1                                  −                                  σ                                  (                                  z                                  )                                  )                                  ≈                                  −                                  0.269                                          -(1 - \sigma(z)) \approx -0.269                           −(1−σ(z))≈−0.269 )。

  • (                                         m                            (                            r                            )                            =                            1                                  m(r) = 1                     m(r)=1 ):(                                         z                            =                            0                                  z = 0                     z=0 ),(                                         σ                            (                            0                            )                            =                            0.5                                  \sigma(0) = 0.5                     σ(0)=0.5 ),(                                         1                            −                            σ                            (                            0                            )                            =                            0.5                                  1 - \sigma(0) = 0.5                     1−σ(0)=0.5 ),

    • 梯度因子 (                                                   −                                  (                                  1                                  −                                  σ                                  (                                  z                                  )                                  )                                  =                                  −                                  0.5                                          -(1 - \sigma(z)) = -0.5                           −(1−σ(z))=−0.5 )。

梯度绝对值从0.269增加到0.5,阐明更高的 (                                    m                         (                         r                         )                              m(r)                  m(r) ) 导致更大的梯度。

4. 逻辑总结与直观解释

为什么更高的Margin导致更大的Loss?



  • (                                         m                            (                            r                            )                                  m(r)                     m(r) ) 是一个“门槛”,表示人类期望的夸奖差距。
  • 当 (                                         m                            (                            r                            )                                  m(r)                     m(r) ) 更高时,模子的猜测差值 (                                                    r                               θ                                      (                            x                            ,                                       y                               c                                      )                            −                                       r                               θ                                      (                            x                            ,                                       y                               r                                      )                                  r_\theta(x, y_c) - r_\theta(x, y_r)                     rθ​(x,yc​)−rθ​(x,yr​) ) 如果没跟上这个门槛,(                                         z                                  z                     z ) 变小,(                                         σ                            (                            z                            )                                  \sigma(z)                     σ(z) ) 变小,损失变大。
  • 这就像考试:如果及格线从60分进步到80分,而你还是考70分,差距更大,得分(损失的反面)更低。
为什么更大的Loss导致更大的梯度?



  • 损失变大意味着模子当前猜测与目标偏离更多,梯度(误差的导数)天然更大。
  • 更大的梯度推动参数更新更大幅度,使 (                                                    r                               θ                                      (                            x                            ,                                       y                               c                                      )                                  r_\theta(x, y_c)                     rθ​(x,yc​) ) 更快增加,(                                                    r                               θ                                      (                            x                            ,                                       y                               r                                      )                                  r_\theta(x, y_r)                     rθ​(x,yr​) ) 更快减小,满足更高的 (                                         m                            (                            r                            )                                  m(r)                     m(r) )。
整体逻辑



  • 高 (                                         m                            (                            r                            )                                  m(r)                     m(r) ) → 小 (                                         z                                  z                     z ) → 小 (                                         σ                            (                            z                            )                                  \sigma(z)                     σ(z) ) → 大 (                                         L                                  L                     L ) → 大梯度 → 大更新。
  • 这是Margin Loss的焦点:通过引入偏好程度,放大模子的学习信号,让夸奖差值更好地反映人类的主观判断。

5. 实际意义

在Llama 2中,加入 (                                    m                         (                         r                         )                              m(r)                  m(r) ) 的好处是:


  • 区分偏好强度:不仅知道 (                                                    y                               c                                      >                                       y                               r                                            y_c > y_r                     yc​>yr​ ),还能反映“多好”(如“显著好” vs “略好”)。
  • 加强练习信号:更高的 (                                        m                            (                            r                            )                                  m(r)                     m(r) ) 迫使模子更快调整,确保夸奖差值与人类期望同等。
  • 提升模子质量:更大的梯度加快收敛,使Reward Model更准确地捕获偏好。

6. 总结

Margin Loss通过引入 (                                    m                         (                         r                         )                              m(r)                  m(r) ),将人类偏好的强度融入损失函数。更高的 (                                    m                         (                         r                         )                              m(r)                  m(r) ) 进步了对夸奖差值的要求,若模子未达标,损失增大;损失增大则放大梯度,推动模子更快学习。这个机制既直观又高效,是Llama 2在RLHF中改进Reward Model的关键一步。希望这篇博客能解开你的疑惑!如果还有问题,接待继续探究。
是的,在这个上下文中,“梯度的增大”指的是梯度的绝对值增大。让我详细解释一下这个问题,并以中文博客的情势进一步澄清这个概念,确保你对“梯度增大”的含义及其背后的逻辑有全面的明白。

梯度增大指的是梯度的绝对值增大吗?

在讨论Llama 2中Margin Loss的梯度计算时,我们提到更高的 (                                    m                         (                         r                         )                              m(r)                  m(r) ) 会导致“更大的梯度”。你的问题是:这里的“梯度增大”是否特指梯度的绝对值增大?答案是肯定的。下面我们将渐渐分析这个问题,从数学推导到直观明白,解答你的疑问。

1. 梯度的定义与方向

在神经网络中,梯度 (                                              ∇                            θ                                  L                              \nabla_\theta L                  ∇θ​L ) 是损失函数 (                                    L                              L                  L ) 对模子参数 (                                    θ                              \theta                  θ ) 的偏导数:
                                                    ∇                               θ                                      L                            =                                                   ∂                                  L                                                      ∂                                  θ                                                       \nabla_\theta L = \frac{\partial L}{\partial \theta}                     ∇θ​L=∂θ∂L​


  • 梯度是一个向量:它包罗多个分量,每个分量对应 (                                         θ                                  \theta                     θ ) 中的一个参数。
  • 巨细与方向

    • 巨细:梯度的模(magnitude),即 (                                                   ∣                                               ∇                                     θ                                              L                                  ∣                                  =                                                                            ∑                                           i                                                                                     (                                                                                 ∂                                                    L                                                                                    ∂                                                                       θ                                                       i                                                                                                 )                                                          2                                                                                 |\nabla_\theta L| = \sqrt{\sum_i \left(\frac{\partial L}{\partial \theta_i}\right)^2}                           ∣∇θ​L∣=∑i​(∂θi​∂L​)2               ​ )。
    • 方向:指向损失增加最快的方向。

  • 练习中的作用:优化器(如Adam)使用梯度的负方向((                                         −                                       ∇                               θ                                      L                                  -\nabla_\theta L                     −∇θ​L ))更新参数,以减小损失。
当我们说“梯度增大”时,通常指的是梯度向量的巨细(即绝对值或模)变大,因为这直接影响参数更新的幅度。

2. Margin Loss中的梯度表达式

在Llama 2的Margin Loss中,损失函数为:
                                         L                            =                            −                            log                            ⁡                                       (                               σ                                           (                                               r                                     θ                                              (                                  x                                  ,                                               y                                     c                                              )                                  −                                               r                                     θ                                              (                                  x                                  ,                                               y                                     r                                              )                                  −                                  m                                  (                                  r                                  )                                  )                                          )                                            L = -\log\left(\sigma\left(r_\theta(x, y_c) - r_\theta(x, y_r) - m(r)\right)\right)                     L=−log(σ(rθ​(x,yc​)−rθ​(x,yr​)−m(r)))
定义:
                                         z                            =                                       r                               θ                                      (                            x                            ,                                       y                               c                                      )                            −                                       r                               θ                                      (                            x                            ,                                       y                               r                                      )                            −                            m                            (                            r                            )                                  z = r_\theta(x, y_c) - r_\theta(x, y_r) - m(r)                     z=rθ​(x,yc​)−rθ​(x,yr​)−m(r)
梯度计算为:
                                                                ∂                                  L                                                      ∂                                  θ                                                 =                                                   ∂                                  L                                                      ∂                                  z                                                 ⋅                                                   ∂                                  z                                                      ∂                                  θ                                                       \frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial \theta}                     ∂θ∂L​=∂z∂L​⋅∂θ∂z​
其中:


  • (                                                                ∂                                  L                                                      ∂                                  z                                                 =                            −                            (                            1                            −                            σ                            (                            z                            )                            )                                  \frac{\partial L}{\partial z} = -(1 - \sigma(z))                     ∂z∂L​=−(1−σ(z)) )(上一节推导)。
  • (                                                                ∂                                  z                                                      ∂                                  θ                                                 =                                                   ∂                                               r                                     θ                                              (                                  x                                  ,                                               y                                     c                                              )                                                      ∂                                  θ                                                 −                                                   ∂                                               r                                     θ                                              (                                  x                                  ,                                               y                                     r                                              )                                                      ∂                                  θ                                                       \frac{\partial z}{\partial \theta} = \frac{\partial r_\theta(x, y_c)}{\partial \theta} - \frac{\partial r_\theta(x, y_r)}{\partial \theta}                     ∂θ∂z​=∂θ∂rθ​(x,yc​)​−∂θ∂rθ​(x,yr​)​ )。
完备梯度:
                                                                ∂                                  L                                                      ∂                                  θ                                                 =                            −                            (                            1                            −                            σ                            (                            z                            )                            )                            ⋅                                       (                                                        ∂                                                   r                                        θ                                                  (                                     x                                     ,                                                   y                                        c                                                  )                                                           ∂                                     θ                                                      −                                                        ∂                                                   r                                        θ                                                  (                                     x                                     ,                                                   y                                        r                                                  )                                                           ∂                                     θ                                                      )                                            \frac{\partial L}{\partial \theta} = -(1 - \sigma(z)) \cdot \left(\frac{\partial r_\theta(x, y_c)}{\partial \theta} - \frac{\partial r_\theta(x, y_r)}{\partial \theta}\right)                     ∂θ∂L​=−(1−σ(z))⋅(∂θ∂rθ​(x,yc​)​−∂θ∂rθ​(x,yr​)​)


  • 梯度因子:(                                         −                            (                            1                            −                            σ                            (                            z                            )                            )                                  -(1 - \sigma(z))                     −(1−σ(z)) ) 是一个标量,始终为负值(因为 (                                         0                            <                            σ                            (                            z                            )                            <                            1                                  0 < \sigma(z) < 1                     0<σ(z)<1 ))。
  • 方向部分:(                                                                ∂                                               r                                     θ                                              (                                  x                                  ,                                               y                                     c                                              )                                                      ∂                                  θ                                                 −                                                   ∂                                               r                                     θ                                              (                                  x                                  ,                                               y                                     r                                              )                                                      ∂                                  θ                                                       \frac{\partial r_\theta(x, y_c)}{\partial \theta} - \frac{\partial r_\theta(x, y_r)}{\partial \theta}                     ∂θ∂rθ​(x,yc​)​−∂θ∂rθ​(x,yr​)​ ) 是一个向量,决定了梯度的方向。

3. 更高的 (                                    m                         (                         r                         )                              m(r)                  m(r) ) 如何影响梯度?

影响梯度的巨细



  • 当 (                                        m                            (                            r                            )                                  m(r)                     m(r) ) 增加时:

    • (                                                   z                                  =                                               r                                     θ                                              (                                  x                                  ,                                               y                                     c                                              )                                  −                                               r                                     θ                                              (                                  x                                  ,                                               y                                     r                                              )                                  −                                  m                                  (                                  r                                  )                                          z = r_\theta(x, y_c) - r_\theta(x, y_r) - m(r)                           z=rθ​(x,yc​)−rθ​(x,yr​)−m(r) ) 减小。
    • (                                                   σ                                  (                                  z                                  )                                          \sigma(z)                           σ(z) ) 减小(sigmoid函数单调递增)。
    • (                                                   1                                  −                                  σ                                  (                                  z                                  )                                          1 - \sigma(z)                           1−σ(z) ) 增大。
    • (                                                   −                                  (                                  1                                  −                                  σ                                  (                                  z                                  )                                  )                                          -(1 - \sigma(z))                           −(1−σ(z)) ) 的绝对值增大(负值的幅度变大)。

比方:


  • (                                         m                            (                            r                            )                            =                            0                                  m(r) = 0                     m(r)=0 ):(                                         z                            =                            1                                  z = 1                     z=1 ),(                                         σ                            (                            1                            )                            ≈                            0.731                                  \sigma(1) \approx 0.731                     σ(1)≈0.731 ),(                                         −                            (                            1                            −                            σ                            (                            1                            )                            )                            ≈                            −                            0.269                                  -(1 - \sigma(1)) \approx -0.269                     −(1−σ(1))≈−0.269 )。
  • (                                         m                            (                            r                            )                            =                            1                                  m(r) = 1                     m(r)=1 ):(                                         z                            =                            0                                  z = 0                     z=0 ),(                                         σ                            (                            0                            )                            =                            0.5                                  \sigma(0) = 0.5                     σ(0)=0.5 ),(                                         −                            (                            1                            −                            σ                            (                            0                            )                            )                            =                            −                            0.5                                  -(1 - \sigma(0)) = -0.5                     −(1−σ(0))=−0.5 )。
标量因子 (                                    −                         (                         1                         −                         σ                         (                         z                         )                         )                              -(1 - \sigma(z))                  −(1−σ(z)) ) 的绝对值从0.269增加到0.5。
梯度的绝对值

梯度的模为:
                                                    ∣                                                        ∂                                     L                                                           ∂                                     θ                                                      ∣                                      =                                       ∣                               −                               (                               1                               −                               σ                               (                               z                               )                               )                               ∣                                      ⋅                                       ∣                                                        ∂                                                   r                                        θ                                                  (                                     x                                     ,                                                   y                                        c                                                  )                                                           ∂                                     θ                                                      −                                                        ∂                                                   r                                        θ                                                  (                                     x                                     ,                                                   y                                        r                                                  )                                                           ∂                                     θ                                                      ∣                                            \left|\frac{\partial L}{\partial \theta}\right| = \left|-(1 - \sigma(z))\right| \cdot \left|\frac{\partial r_\theta(x, y_c)}{\partial \theta} - \frac{\partial r_\theta(x, y_r)}{\partial \theta}\right|                                    ​∂θ∂L​               ​=∣−(1−σ(z))∣⋅               ​∂θ∂rθ​(x,yc​)​−∂θ∂rθ​(x,yr​)​               ​


  • (                                         ∣                            −                            (                            1                            −                            σ                            (                            z                            )                            )                            ∣                            =                            1                            −                            σ                            (                            z                            )                                  |-(1 - \sigma(z))| = 1 - \sigma(z)                     ∣−(1−σ(z))∣=1−σ(z) )(因为 (                                         −                            (                            1                            −                            σ                            (                            z                            )                            )                            <                            0                                  -(1 - \sigma(z)) < 0                     −(1−σ(z))<0 ))。
  • (                                                                ∂                                               r                                     θ                                              (                                  x                                  ,                                               y                                     c                                              )                                                      ∂                                  θ                                                 −                                                   ∂                                               r                                     θ                                              (                                  x                                  ,                                               y                                     r                                              )                                                      ∂                                  θ                                                       \frac{\partial r_\theta(x, y_c)}{\partial \theta} - \frac{\partial r_\theta(x, y_r)}{\partial \theta}                     ∂θ∂rθ​(x,yc​)​−∂θ∂rθ​(x,yr​)​ ) 是模子内部计算的梯度向量,其巨细取决于当前参数和输入。
当 (                                    m                         (                         r                         )                              m(r)                  m(r) ) 增加时,(                                    1                         −                         σ                         (                         z                         )                              1 - \sigma(z)                  1−σ(z) ) 增大,直接导致 (                                    ∣                                              ∂                               L                                                 ∂                               θ                                            ∣                              \left|\frac{\partial L}{\partial \theta}\right|                                ​∂θ∂L​              ​ ) 增大。这里的“梯度增大”正是指梯度向量的绝对值(模)变大。
方向是否改变?



  • (                                         −                            (                            1                            −                            σ                            (                            z                            )                            )                                  -(1 - \sigma(z))                     −(1−σ(z)) ) 只影响梯度的巨细(标量缩放),不改变方向。
  • 方向由 (                                                                ∂                                               r                                     θ                                              (                                  x                                  ,                                               y                                     c                                              )                                                      ∂                                  θ                                                 −                                                   ∂                                               r                                     θ                                              (                                  x                                  ,                                               y                                     r                                              )                                                      ∂                                  θ                                                       \frac{\partial r_\theta(x, y_c)}{\partial \theta} - \frac{\partial r_\theta(x, y_r)}{\partial \theta}                     ∂θ∂rθ​(x,yc​)​−∂θ∂rθ​(x,yr​)​ ) 决定,与 (                                         m                            (                            r                            )                                  m(r)                     m(r) ) 无关。
因此,“梯度增大”特指绝对值增大,方向保持同等。

4. 为什么关注绝对值?

在练习过程中,梯度的巨细(绝对值)决定了参数更新的幅度:


  • 更新公式:(                                         θ                            ←                            θ                            −                            η                            ⋅                                       ∇                               θ                                      L                                  \theta \leftarrow \theta - \eta \cdot \nabla_\theta L                     θ←θ−η⋅∇θ​L )((                                        η                                  \eta                     η) 是学习率)。
  • (                                         ∣                                       ∇                               θ                                      L                            ∣                                  |\nabla_\theta L|                     ∣∇θ​L∣ ) 越大,参数变革越大。
更高的 (                                    m                         (                         r                         )                              m(r)                  m(r) ) 使 (                                    ∣                                   ∇                            θ                                  L                         ∣                              |\nabla_\theta L|                  ∣∇θ​L∣ ) 增大,意味着:


  • 模子感知到当前猜测与人类期望的差距更大。
  • 需要更大幅度调整参数,使 (                                                    r                               θ                                      (                            x                            ,                                       y                               c                                      )                                  r_\theta(x, y_c)                     rθ​(x,yc​) ) 增加,(                                                    r                               θ                                      (                            x                            ,                                       y                               r                                      )                                  r_\theta(x, y_r)                     rθ​(x,yr​) ) 减小,以满足更高的margin。

5. 举例验证

继续之前的例子:


  • (                                                          r                                  θ                                          (                               x                               ,                                           y                                  c                                          )                               =                               2                                      r_\theta(x, y_c) = 2                        rθ​(x,yc​)=2 ),(                                                          r                                  θ                                          (                               x                               ,                                           y                                  r                                          )                               =                               1                                      r_\theta(x, y_r) = 1                        rθ​(x,yr​)=1 )。
  • 假设 (                                                                       ∂                                                   r                                        θ                                                  (                                     x                                     ,                                                   y                                        c                                                  )                                                           ∂                                     θ                                                      =                               [                               0.1                               ,                               0.2                               ]                                      \frac{\partial r_\theta(x, y_c)}{\partial \theta} = [0.1, 0.2]                        ∂θ∂rθ​(x,yc​)​=[0.1,0.2] ),(                                                                       ∂                                                   r                                        θ                                                  (                                     x                                     ,                                                   y                                        r                                                  )                                                           ∂                                     θ                                                      =                               [                               0.05                               ,                               0.1                               ]                                      \frac{\partial r_\theta(x, y_r)}{\partial \theta} = [0.05, 0.1]                        ∂θ∂rθ​(x,yr​)​=[0.05,0.1] )。
  • (                                                                       ∂                                     z                                                           ∂                                     θ                                                      =                               [                               0.1                               −                               0.05                               ,                               0.2                               −                               0.1                               ]                               =                               [                               0.05                               ,                               0.1                               ]                                      \frac{\partial z}{\partial \theta} = [0.1 - 0.05, 0.2 - 0.1] = [0.05, 0.1]                        ∂θ∂z​=[0.1−0.05,0.2−0.1]=[0.05,0.1] ),模 (                                                                       0.0                                                   5                                        2                                                  +                                     0.                                                   1                                        2                                                                   ≈                               0.112                                      \sqrt{0.05^2 + 0.1^2} \approx 0.112                        0.052+0.12              ​≈0.112 )。
  • (                                              m                               (                               r                               )                               =                               0                                      m(r) = 0                        m(r)=0 ):

    • (                                                   z                                  =                                  1                                          z = 1                           z=1 ),(                                                   −                                  (                                  1                                  −                                  σ                                  (                                  1                                  )                                  )                                  ≈                                  −                                  0.269                                          -(1 - \sigma(1)) \approx -0.269                           −(1−σ(1))≈−0.269 ),
    • (                                                                ∇                                     θ                                              L                                  =                                  −                                  0.269                                  ⋅                                  [                                  0.05                                  ,                                  0.1                                  ]                                  =                                  [                                  −                                  0.01345                                  ,                                  −                                  0.0269                                  ]                                          \nabla_\theta L = -0.269 \cdot [0.05, 0.1] = [-0.01345, -0.0269]                           ∇θ​L=−0.269⋅[0.05,0.1]=[−0.01345,−0.0269] ),
    • 模 (                                                   ∣                                               ∇                                     θ                                              L                                  ∣                                  ≈                                  0.0301                                          |\nabla_\theta L| \approx 0.0301                           ∣∇θ​L∣≈0.0301 )。

  • (                                              m                               (                               r                               )                               =                               1                                      m(r) = 1                        m(r)=1 ):

    • (                                                   z                                  =                                  0                                          z = 0                           z=0 ),(                                                   −                                  (                                  1                                  −                                  σ                                  (                                  0                                  )                                  )                                  =                                  −                                  0.5                                          -(1 - \sigma(0)) = -0.5                           −(1−σ(0))=−0.5 ),
    • (                                                                ∇                                     θ                                              L                                  =                                  −                                  0.5                                  ⋅                                  [                                  0.05                                  ,                                  0.1                                  ]                                  =                                  [                                  −                                  0.025                                  ,                                  −                                  0.05                                  ]                                          \nabla_\theta L = -0.5 \cdot [0.05, 0.1] = [-0.025, -0.05]                           ∇θ​L=−0.5⋅[0.05,0.1]=[−0.025,−0.05] ),
    • 模 (                                                   ∣                                               ∇                                     θ                                              L                                  ∣                                  ≈                                  0.0559                                          |\nabla_\theta L| \approx 0.0559                           ∣∇θ​L∣≈0.0559 )。

梯度模从0.0301增加到0.0559,绝对值确实增大。

6. 直观解释



  • 更高的 (                                              m                               (                               r                               )                                      m(r)                        m(r) ) 像更高的门槛:如果人类说 (                                                    y                               c                                            y_c                     yc​ ) “显著好于” (                                                    y                               r                                            y_r                     yr​ ),模子必须给出更大的夸奖差值。当前差值不足时,损失变大,梯度绝对值随之增大,推动模子“努力”调整。
  • 梯度绝对值决定更新强度:更大的绝对值意味着参数变革更剧烈,帮助模子更快接近目标。

7. 总结

是的,“梯度的增大”在这里指的是梯度的绝对值(模)增大。更高的 (                                   m                         (                         r                         )                              m(r)                  m(r) ) 使 (                                    z                              z                  z ) 减小,(                                    1                         −                         σ                         (                         z                         )                              1 - \sigma(z)                  1−σ(z) ) 增大,梯度因子 (                                    −                         (                         1                         −                         σ                         (                         z                         )                         )                              -(1 - \sigma(z))                  −(1−σ(z)) ) 的绝对值变大,从而使整个梯度向量的模增大。这反映了模子需要更强的更新信号来满足更高的偏好尺度。希望这篇分析能清楚解答你的疑问!
后记

2025年3月1日16点33分于上海,在grok3大模子辅助下完成。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

我爱普洱茶

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表