介绍交叉熵损失(Cross-Entropy Loss)以及交叉熵在对比学习中的应用:中英 ...

打印 上一主题 下一主题

主题 894|帖子 894|积分 2682

中文版

本文解释 交叉熵损失(Cross-Entropy Loss),并团结对比学习的应用说明它如何工作,以及如何让正样本对更近、负样本对更远。

什么是交叉熵损失?

交叉熵损失是呆板学习中常用的一种损失函数,主要用于分类任务,用来衡量模型推测的概率分布和真实分布之间的差别。
其公式为:
                                         L                            =                            −                                       ∑                                           i                                  =                                  1                                          C                                                 y                               i                                      log                            ⁡                            (                                                   y                                  ^                                          i                                      )                                  L = -\sum_{i=1}^C y_i \log(\hat{y}_i)                     L=−i=1∑C​yi​log(y^​i​)


  • (                                         C                                  C                     C ):类别数。
  • (                                                    y                               i                                            y_i                     yi​ ):真实类别分布(通常是独热编码,只有真实类别对应位置为 1)。
  • (                                                                y                                  ^                                          i                                            \hat{y}_i                     y^​i​ ):模型推测的概率分布(通常是通过 softmax 得到的概率值)。
如果只考虑单一样本,交叉熵公式可以简化为:
                                         L                            =                            −                            log                            ⁡                            (                                                   y                                  ^                                          j                                      )                                  L = -\log(\hat{y}_j)                     L=−log(y^​j​)


  • (                                         j                                  j                     j ) 是真实类别的索引。
  • (                                                                y                                  ^                                          j                                            \hat{y}_j                     y^​j​ ) 是模型推测的真实类别概率。

交叉熵损失如何工作?


  • 惩罚错误推测:

    • 如果模型推测的真实类别概率 (                                                                              y                                        ^                                                  j                                                      \hat{y}_j                           y^​j​ ) 较小,则损失 (                                                   −                                  log                                  ⁡                                  (                                                             y                                        ^                                                  j                                              )                                          -\log(\hat{y}_j)                           −log(y^​j​) ) 很大,从而对模型施加较大的惩罚,迫使模型学习更高的真实类别概率。
    • 比方,若 (                                                                              y                                        ^                                                  j                                              =                                  0.1                                          \hat{y}_j = 0.1                           y^​j​=0.1 ),损失 (                                                   −                                  log                                  ⁡                                  (                                  0.1                                  )                                  =                                  2.3                                          -\log(0.1) = 2.3                           −log(0.1)=2.3 )。

  • 奖励精确推测:

    • 如果模型推测的真实类别概率 (                                                                              y                                        ^                                                  j                                                      \hat{y}_j                           y^​j​ ) 较大(靠近 1),损失很小,表现模型在这一样本上的推测靠近抱负。
    • 比方,若 (                                                                              y                                        ^                                                  j                                              =                                  0.9                                          \hat{y}_j = 0.9                           y^​j​=0.9 ),损失 (                                                   −                                  log                                  ⁡                                  (                                  0.9                                  )                                  =                                  0.11                                          -\log(0.9) = 0.11                           −log(0.9)=0.11 )。

  • 鼓励模型信心:

    • 模型推测越靠近 1 或 0(置信度更高),交叉熵的结果会更低,模型学习的效果越好。


Softmax 与交叉熵的关系

交叉熵损失通常和 Softmax 一起使用,Softmax 是将原始的 logits 转换为概率分布的函数:
                                                                y                                  ^                                          i                                      =                                                   exp                                  ⁡                                  (                                               z                                     i                                              )                                                                   ∑                                                   k                                        =                                        1                                                  C                                              exp                                  ⁡                                  (                                               z                                     k                                              )                                                       \hat{y}_i = \frac{\exp(z_i)}{\sum_{k=1}^C \exp(z_k)}                     y^​i​=∑k=1C​exp(zk​)exp(zi​)​
此中:


  • (                                                    z                               i                                            z_i                     zi​ ):模型输出的 logits 值(未归一化的分数)。
  • (                                                                y                                  ^                                          i                                            \hat{y}_i                     y^​i​ ):Softmax 输出的归一化概率。
Softmax 确保输出概率总和为 1,使得它适合作为概率分布与真实标签举行比较。

交叉熵在对比学习中的应用

在对比学习任务(如 CLIP)中,交叉熵损失被用来拉近正样本对的相似度,同时拉远负样本对的相似度。
比方,在 CLIP 模型中:

  • 输入:

    • 一批图像和对应的文本形貌。
    • 模型通过编码器天生图像和文本的嵌入向量 (                                                   z_image                                  ,                                  z_text                                          \text{z\_image}, \text{z\_text}                           z_image,z_text )。

  • 计算 logits(相似度矩阵):

    • 两个向量的相似度通常用点积或余弦相似度计算:
                                                              logits_per_image                                     [                                     i                                     ]                                     [                                     j                                     ]                                     =                                     sim                                     (                                                   z_image                                        i                                                  ,                                                   z_text                                        j                                                  )                                              \text{logits\_per\_image}[j] = \text{sim}(\text{z\_image}_i, \text{z\_text}_j)                              logits_per_image[j]=sim(z_imagei​,z_textj​)

  • 计算概率分布:

    • 使用 softmax 将相似度矩阵的每一行归一化为概率分布,表现图像 (                                                   i                                          i                           i ) 对应文本 (                                                   j                                          j                           j ) 的匹配概率。

  • 交叉熵损失:

    • 对于每个图像 (                                                   i                                          i                           i ),真实匹配文本的索引为 (                                                   j                                          j                           j ),交叉熵损失是:
                                                              L                                     =                                     −                                     log                                     ⁡                                     (                                     P                                     (                                     positive                                     )                                     )                                              L = -\log(P(\text{positive}))                              L=−log(P(positive))
    • (                                                   P                                  (                                  positive                                  )                                          P(\text{positive})                           P(positive) ) 是正样本的 softmax 概率值。


举例说明交叉熵如何拉近正样本,拉远负样本

假设例子:


  • 批量大小 = 3,logits(相似度矩阵):
                                                       logits_per_image                                  =                                               [                                                                                                   2.0                                                                                                           0.5                                                                                                                             −                                                    1.0                                                                                                                                                      0.3                                                                                                           1.8                                                                                                           0.2                                                                                                                                                       −                                                    0.5                                                                                                                            0.4                                                                                                           1.5                                                                                               ]                                                      \text{logits\_per\_image} = \begin{bmatrix} 2.0 & 0.5 & -1.0 \\ 0.3 & 1.8 & 0.2 \\ -0.5 & 0.4 & 1.5 \end{bmatrix}                           logits_per_image=                 ​2.00.3−0.5​0.51.80.4​−1.00.21.5​                 ​
  • Softmax 概率:

    • 第一行(图像 1 的概率分布):
                                                              P                                     (                                     positive                                     )                                     =                                                                  exp                                           ⁡                                           (                                           2.0                                           )                                                                     exp                                           ⁡                                           (                                           2.0                                           )                                           +                                           exp                                           ⁡                                           (                                           0.5                                           )                                           +                                           exp                                           ⁡                                           (                                           −                                           1.0                                           )                                                                ≈                                     0.71                                              P(\text{positive}) = \frac{\exp(2.0)}{\exp(2.0) + \exp(0.5) + \exp(-1.0)} \approx 0.71                              P(positive)=exp(2.0)+exp(0.5)+exp(−1.0)exp(2.0)​≈0.71
                                                              P                                     (                                     negative                                     ,                                     j                                     =                                     2                                     )                                     ≈                                     0.23                                     ,                                                 P                                     (                                     negative                                     ,                                     j                                     =                                     3                                     )                                     ≈                                     0.06                                              P(\text{negative}, j=2) \approx 0.23, \quad P(\text{negative}, j=3) \approx 0.06                              P(negative,j=2)≈0.23,P(negative,j=3)≈0.06

  • 交叉熵损失:

    • 如果图像 1 和文本 1 是正样本对:
                                                              L                                     =                                     −                                     log                                     ⁡                                     (                                     P                                     (                                     positive                                     )                                     )                                     ≈                                     −                                     log                                     ⁡                                     (                                     0.71                                     )                                     =                                     0.34                                              L = -\log(P(\text{positive})) \approx -\log(0.71) = 0.34                              L=−log(P(positive))≈−log(0.71)=0.34

  • 优化目标:

    • 提高正样本概率: 比方将 logits 中的 ( 2.0 ) 调高。
    • 降低负样本概率: 比方将 logits 中的 ( 0.5, -1.0 ) 调低。


梯度更新

通过反向传播,交叉熵损失会对 logits 施加以下影响:

  • 正样本对: 提升其 logits 值,让正样本的相似度更高。
  • 负样本对: 降低其 logits 值,让负样本的相似度更低。
具体过程请参考笔者的另一篇博客:通过模拟对CLIP举行解释:如何通过梯度提升正样本的相似度?

总结

在对比学习中,交叉熵损失团结 softmax 通过最大化正样本对的概率 (                                    P                         (                         positive                         )                              P(\text{positive})                  P(positive) ) 和最小化负样本对的概率,从而学习到一个区分度更高的嵌入空间。这种方法被广泛应用于大模型(如 CLIP、SimCLR)中,用于学习视觉与文本、差别视角图像等的语义匹配。
英文版


What is Cross-Entropy Loss?

Cross-entropy loss measures the difference between two probability distributions:

  • The true labels’ distribution (ground truth).
  • The predicted probability distribution (from the model, e.g., softmax output).
The formula for cross-entropy loss is:
                                         L                            =                            −                                       ∑                                           i                                  =                                  1                                          C                                                 y                               i                                      log                            ⁡                            (                                                   y                                  ^                                          i                                      )                                  L = -\sum_{i=1}^C y_i \log(\hat{y}_i)                     L=−i=1∑C​yi​log(y^​i​)
Where:


  • (                                         C                                  C                     C ): The number of classes.
  • (                                                    y                               i                                            y_i                     yi​ ): The true label for class (                                         i                                  i                     i ) (1 if true, 0 otherwise, in one-hot encoding).
  • (                                                                y                                  ^                                          i                                            \hat{y}_i                     y^​i​ ): The predicted probability for class (                                         i                                  i                     i ) (output of the softmax).
For a single example, if the ground truth class is (                                    j                              j                  j ), the loss simplifies to:
                                         L                            =                            −                            log                            ⁡                            (                                                   y                                  ^                                          j                                      )                                  L = -\log(\hat{y}_j)                     L=−log(y^​j​)

How Does Cross-Entropy Work?


  • Penalizes incorrect predictions:

    • If the model predicts a probability far from the true class ((                                                                              y                                        ^                                                  j                                                      \hat{y}_j                           y^​j​ ) is small), the loss is high because (                                                   −                                  log                                  ⁡                                  (                                                             y                                        ^                                                  j                                              )                                          -\log(\hat{y}_j)                           −log(y^​j​) ) is large.
    • Example: If (                                                                              y                                        ^                                                  j                                              =                                  0.1                                          \hat{y}_j = 0.1                           y^​j​=0.1 ), then (                                                   −                                  log                                  ⁡                                  (                                  0.1                                  )                                  =                                  2.3                                          -\log(0.1) = 2.3                           −log(0.1)=2.3 ).

  • Rewards correct predictions:

    • If the model predicts a high probability for the true class ((                                                                              y                                        ^                                                  j                                                      \hat{y}_j                           y^​j​ ) close to 1), the loss is small.
    • Example: If (                                                                              y                                        ^                                                  j                                              =                                  0.9                                          \hat{y}_j = 0.9                           y^​j​=0.9 ), then (                                                   −                                  log                                  ⁡                                  (                                  0.9                                  )                                  =                                  0.11                                          -\log(0.9) = 0.11                           −log(0.9)=0.11 ).

  • Encourages probabilistic confidence:

    • Predictions close to 0 or 1 result in higher confidence and a lower loss.


Connection to Softmax

Cross-entropy loss is typically used after a softmax activation function, which normalizes raw logits into probabilities:
                                                                y                                  ^                                          i                                      =                                                   exp                                  ⁡                                  (                                               z                                     i                                              )                                                                   ∑                                                   k                                        =                                        1                                                  C                                              exp                                  ⁡                                  (                                               z                                     k                                              )                                                       \hat{y}_i = \frac{\exp(z_i)}{\sum_{k=1}^C \exp(z_k)}                     y^​i​=∑k=1C​exp(zk​)exp(zi​)​
Where:


  • (                                                    z                               i                                            z_i                     zi​ ): The raw score (logit) for class (                                         i                                  i                     i ).
  • (                                                                y                                  ^                                          i                                            \hat{y}_i                     y^​i​ ): The predicted probability for class (                                         i                                  i                     i ).
The softmax ensures that the output probabilities sum to 1, making them suitable for comparing to one-hot encoded true labels.

How It Maximizes Positive Class Similarity (Contrastive Setting)

In contrastive learning (e.g., CLIP), cross-entropy loss is used to pull positive pairs closer together while pushing negative pairs apart. Here’s how it works:

  • Positive Pair Similarity:

    • If the predicted similarity for the positive pair (e.g., (                                                                              y                                        ^                                                  positive                                                      \hat{y}_{\text{positive}}                           y^​positive​ )) is high, (                                                   −                                  log                                  ⁡                                  (                                                             y                                        ^                                                  positive                                              )                                          -\log(\hat{y}_{\text{positive}})                           −log(y^​positive​) ) is small, reducing the loss. This encourages the model to further increase the similarity.

  • Negative Pair Similarity:

    • For negative pairs, their probabilities are part of the denominator in the softmax:
                                                              P                                     (                                     positive                                     )                                     =                                                                  exp                                           ⁡                                           (                                           sim                                           (                                           pos                                           )                                           )                                                                     exp                                           ⁡                                           (                                           sim                                           (                                           pos                                           )                                           )                                           +                                           ∑                                           exp                                           ⁡                                           (                                           sim                                           (                                           neg                                           )                                           )                                                                         P(\text{positive}) = \frac{\exp(\text{sim}(\text{pos}))}{\exp(\text{sim}(\text{pos})) + \sum \exp(\text{sim}(\text{neg}))}                              P(positive)=exp(sim(pos))+∑exp(sim(neg))exp(sim(pos))​
    • Increasing (                                                   exp                                  ⁡                                  (                                  sim                                  (                                  neg                                  )                                  )                                          \exp(\text{sim}(\text{neg}))                           exp(sim(neg)) ) reduces (                                                   P                                  (                                  positive                                  )                                          P(\text{positive})                           P(positive) ), increasing the loss. Therefore, the model learns to lower the similarity for negative pairs.

By optimizing the cross-entropy loss, the model dynamically adjusts logits to maximize the positive pair similarity while minimizing the negative pair similarity.
后记

2024年12月13日22点11分于上海,在GPT4o大模型辅助下完成。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

络腮胡菲菲

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表