超大规模分类(一):噪声对比估计(Noise Contrastive Estimation, NCE) ...

打印 上一主题 下一主题

主题 1048|帖子 1048|积分 3144

NCE损失对应的论文为《A fast and simple algorithm for training neural probabilistic language models》,发表于2012年的ICML会议。
配景

在2012年,语言模型一样平常接纳n-gram的方法,统计单词/上下文间的共现关系,比神经概率语言模型(neural probabilistic language models, NPLMs)效果好。
现在主流的语言模型都是神经概率语言模 型,核心思想是已知上下文                                   h                              h                  h,预测下一个词为                                   w                              w                  w的概率,通过肯定的解码方法(比方greedy search、beam search等),对概率做解码,得到下一个词。Greedy search可以理解为选择概率最大的那个词。
2012年神经概率语言模型效果不好的原因是难训练。一方面自然是硬件的制约,那一年英伟达刚发布GTX680,和现在的A100、H100完全没法比。当时老黄不给力,学术界也没办法;另一方面是算法服从不可,难以进行大规模的分类学习,将”已知上下文                                   h                              h                  h,预测下一个词为                                             w                            i                                       w_i                  wi​的概率“建模因素类学习任务,目标在于把下一个词分类到词表中的某个词上。
举个例子,已知上下文是“我想去”,必要预测下一个词。词表中有4个词,即['北京','上海','天津','广州'],必要把下一个词归类到词表的4个词里。如果词表有10万个词呢?训不动啊~
这就是当时面临的困境。NCE对分类算法做了优化,使得对大词表做分类任务成为可能。
原理

通俗的配景讲完了,接下来谈谈公式化的原理部门。
问题建模

已知上下文                                   h                              h                  h,预测下一次词为                                   w                              w                  w的概率为:
                                                                                                     P                                        θ                                        h                                                  (                                     w                                     )                                     =                                                                  e                                           x                                           p                                           (                                                           s                                              θ                                                          (                                           w                                           ,                                           h                                           )                                           )                                                                                     ∑                                                               w                                                 i                                                                                          e                                              x                                              p                                              (                                                               s                                                 θ                                                              (                                                               w                                                 i                                                              ,                                              h                                              )                                              )                                                                                                                      (1)                                                       P_{\theta}^h(w)=\frac{exp(s_{\theta}(w,h))}{\sum_{w_i}{exp(s_{\theta}(w_i,h))}}\tag{1}                     Pθh​(w)=∑wi​​exp(sθ​(wi​,h))exp(sθ​(w,h))​(1)
其中,                                             s                            θ                                  (                         w                         ,                         h                         )                              s_{\theta}(w,h)                  sθ​(w,h)表示已知上下文                                   h                              h                  h,下一个词为                                   w                              w                  w的预测得分;                                             ∑                                       w                               i                                                 \sum_{w_i}                  ∑wi​​表示词表内的全部词。
一样平常情况下,                                             s                            θ                                  (                         w                         ,                         h                         )                              s_{\theta}(w,h)                  sθ​(w,h)通过对上下文                                   h                              h                  h表征以及词种别                                   w                              w                  w表征添加多个全连接层盘算得到。最简单的策略,仅对上下文                                   h                              h                  h表征                                             f                            h                                       f_h                  fh​用一个全连接层                                   W                              W                  W做一次映射,再和词种别                                   w                              w                  w表征                                             f                                       w                               i                                                 f_{w_i}                  fwi​​做点积即可。
                                              s                            θ                                  (                         w                         ,                         h                         )                         =                         (                                   f                            h                                  W                         )                         ⋅                                   f                            w                                       s_{\theta}(w,h)=(f_h W) \cdot f_{w}                  sθ​(w,h)=(fh​W)⋅fw​
难度分析

对公式(1)进行分析,
分子部门                                   e                         x                         p                         (                                   s                            θ                                  (                         w                         ,                         h                         )                         )                              exp(s_{\theta}(w,h))                  exp(sθ​(w,h))是好算的,针对单个                                   w                              w                  w,只必要盘算一次。
分母部门KaTeX parse error: \tag works only in display equations不好算,针对单个                                   w                              w                  w,必要盘算                                   e                         x                         p                         (                                   s                            θ                                  (                                   w                            1                                  ,                         h                         )                         )                         ,                         e                         x                         p                         (                                   s                            θ                                  (                                   w                            2                                  ,                         h                         )                         )                         ,                         .                         .                         .                         e                         x                         p                         (                                   s                            θ                                  (                                   w                            n                                  ,                         h                         )                         )                              exp(s_{\theta}(w_1,h)), exp(s_{\theta}(w_2,h)), ...exp(s_{\theta}(w_n,h))                  exp(sθ​(w1​,h)),exp(sθ​(w2​,h)),...exp(sθ​(wn​,h)),如果词表中词很多,盘算量不小。
现在学术界、工业界对超大规模分类的优化根本上都聚焦在怎样优化分母上,比方InfoNCE仅关注batch内的负类样本、KNN softmax对种别聚类,减少种别数目、partial FC对种别做采样以及显存均分来较少盘算量、Inf-CL借助FlashAttention的思想,以空间换时间。
优化策略

既然对词表内n个词的大规模分类任务难做,难办,那就掀桌子不办了!!!

将原多分类任务转换成一个更容易实现的任务——新二分类任务。
除了有正常的真实数据之外,从一个噪声分布里采样噪声数据,对真实数据和噪声数据做二分类,可以证实:随着噪声数据越多,转换后任务的优化目标和转换前任务越接近
新二分类任务

给定上下文                                   h                              h                  h后,现在有两个数据分布,一个是真实数据分布                                             P                            d                            h                                  (                         w                         )                              P_d^h(w)                  Pdh​(w)(现实应该写成                                             P                            d                                  (                         w                         ∣                         h                         )                              P_d(w|h)                  Pd​(w∣h),简化情势写成                                             P                            d                            h                                  (                         w                         )                              P_d^h(w)                  Pdh​(w)),另一个是噪声数据分布                                             P                            n                                  (                         w                         )                              P_n(w)                  Pn​(w),真实数据和噪声数据的比例是1:k。以是,训练数据的完备分布是                                             P                            h                                  (                         w                         )                         =                                   1                                       k                               +                               1                                                      P                            d                            h                                  (                         w                         )                         +                                   k                                       k                               +                               1                                                      P                            n                                  (                         w                         )                              P^h(w)=\frac{1}{k+1}P_d^h(w)+\frac{k}{k+1}P_n(w)                  Ph(w)=k+11​Pdh​(w)+k+1k​Pn​(w),训练任务是                                   D                         =                         1                              D=1                  D=1(分辨真实数据)和                                   D                         =                         0                              D=0                  D=0(分辨噪声数据)。
我们盼望优化神经网络参数                                   θ                              \theta                  θ,来拟合真实数据分布                                             P                            d                            h                                  (                         w                         )                         =                                   P                            θ                            h                                  (                         w                         )                              P_d^h(w)=P^h_{\theta}(w)                  Pdh​(w)=Pθh​(w),后者就是我们学到的数据分布                                             P                            θ                            h                                  (                         w                         )                              P^h_{\theta}(w)                  Pθh​(w),于是,训练数据的完备分布写成                                             P                            h                                  (                         w                         ,                         θ                         )                         =                                   1                                       k                               +                               1                                                      P                            θ                            h                                  (                         w                         )                         +                                   k                                       k                               +                               1                                                      P                            n                                  (                         w                         )                              P^h(w,\theta)=\frac{1}{k+1}P^h_{\theta}(w)+\frac{k}{k+1}P_n(w)                  Ph(w,θ)=k+11​Pθh​(w)+k+1k​Pn​(w)
训练目标一样平常是最大化后验概率                                             P                            h                                  (                         D                         ∣                         w                         ,                         θ                         )                              P^h(D|w,\theta)                  Ph(D∣w,θ)的对数似然期望                                   E                                   [                            l                            o                            g                            (                                       P                               h                                      (                            D                            ∣                            w                            ,                            θ                            )                            )                            ]                                       E \left[log(P^h(D|w,\theta))\right]                  E[log(Ph(D∣w,θ))],必要盘算后验概率                                             P                            h                                  (                         D                         ∣                         w                         ,                         θ                         )                              P^h(D|w,\theta)                  Ph(D∣w,θ)。
                                                                                                     P                                        h                                                  (                                     D                                     ∣                                     w                                     ,                                     θ                                     )                                     =                                                   P                                        h                                                  (                                     D                                     =                                     1                                     ∣                                     w                                     ,                                     θ                                     )                                     +                                                   P                                        h                                                  (                                     D                                     =                                     0                                     ∣                                     w                                     ,                                     θ                                     )                                                                            (2)                                                       P^h(D|w,\theta)=P^h(D=1|w,\theta)+P^h(D=0|w,\theta)\tag{2}                     Ph(D∣w,θ)=Ph(D=1∣w,θ)+Ph(D=0∣w,θ)(2)
真实数据分布的后验概率为:
                                                                                                                                                                                                                                                                      P                                                                h                                                                                  (                                                             D                                                             =                                                             1                                                             ∣                                                             w                                                             ,                                                             θ                                                             )                                                                                                                                                                                             =                                                                                                                                  P                                                                      h                                                                                          (                                                                   w                                                                   ,                                                                   θ                                                                   ∣                                                                   D                                                                   =                                                                   1                                                                   )                                                                                                                                     P                                                                      h                                                                                          (                                                                   w                                                                   ,                                                                   θ                                                                   )                                                                                                                              P                                                                h                                                                                  (                                                             D                                                             =                                                             1                                                             )                                                                                                                                                                                                                                                                                                                      =                                                                                                                                  P                                                                      θ                                                                      h                                                                                          (                                                                   w                                                                   )                                                                                                                                     1                                                                                               k                                                                         +                                                                         1                                                                                                                                          P                                                                      θ                                                                      h                                                                                          (                                                                   w                                                                   )                                                                   +                                                                                           k                                                                                               k                                                                         +                                                                         1                                                                                                                                          P                                                                      n                                                                                          (                                                                   w                                                                   )                                                                                                                              1                                                                                       k                                                                   +                                                                   1                                                                                                                                                                                                                                                                                                                                                                 =                                                                                                                                  P                                                                      θ                                                                      h                                                                                          (                                                                   w                                                                   )                                                                                                                                     P                                                                      θ                                                                      h                                                                                          (                                                                   w                                                                   )                                                                   +                                                                   k                                                                                           P                                                                      n                                                                                          (                                                                   w                                                                   )                                                                                                                                                                                                                                                                                     (3)                                                       \begin{equation}\begin{aligned} P^h(D=1|w,\theta) &= \frac{P^h(w,\theta|D=1)}{P^h(w,\theta)}P^h(D=1) \\ &=\frac{P_{\theta}^h(w)}{\frac{1}{k+1}P^h_{\theta}(w)+\frac{k}{k+1}P_n(w)}\frac{1}{k+1} \\ &=\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)} \end{aligned} \end{equation} \tag{3}                     Ph(D=1∣w,θ)​=Ph(w,θ)Ph(w,θ∣D=1)​Ph(D=1)=k+11​Pθh​(w)+k+1k​Pn​(w)Pθh​(w)​k+11​=Pθh​(w)+kPn​(w)Pθh​(w)​​​(3)
我们来看看等式为什么成立


  • 边缘概率                                                   P                               h                                      (                            w                            ,                            θ                            )                            =                                       1                                           k                                  +                                  1                                                            P                               θ                               h                                      (                            w                            )                            +                                       k                                           k                                  +                                  1                                                            P                               n                                      (                            w                            )                                  P^h(w,\theta)=\frac{1}{k+1}P^h_{\theta}(w)+\frac{k}{k+1}P_n(w)                     Ph(w,θ)=k+11​Pθh​(w)+k+1k​Pn​(w)
  • 先验概率                                                   P                               h                                      (                            D                            =                            1                            )                            =                                       1                                           k                                  +                                  1                                                       P^h(D=1)=\frac{1}{k+1}                     Ph(D=1)=k+11​,原因是真实数据和噪声数据的比例是1:k。
  • 似然函数                                                         P                                  h                                          (                               w                               ,                               θ                               ∣                               D                               =                               1                               )                               =                                           P                                  θ                                  h                                          (                               w                               )                                      P^h(w,\theta|D=1)=P^h_{\theta}(w)                        Ph(w,θ∣D=1)=Pθh​(w),表明在真实数据分布下,从词表里预测下一个词为                                             w                                      w                        w的概率是                                                         P                                  θ                                  h                                          (                               w                               )                                      P^h_{\theta}(w)                        Pθh​(w),这就是我们想拟合的函数。
类似的,噪声数据分布的后验概率为:
                                                                                                                                                                                                                                                                      P                                                                h                                                                                  (                                                             D                                                             =                                                             0                                                             ∣                                                             w                                                             ,                                                             θ                                                             )                                                                                                                                                                                             =                                                                                                                                  P                                                                      h                                                                                          (                                                                   w                                                                   ,                                                                   θ                                                                   ∣                                                                   D                                                                   =                                                                   0                                                                   )                                                                                                                                     P                                                                      h                                                                                          (                                                                   w                                                                   ,                                                                   θ                                                                   )                                                                                                                              P                                                                h                                                                                  (                                                             D                                                             =                                                             0                                                             )                                                                                                                                                                                                                                                                                                                      =                                                                                                                                  P                                                                      n                                                                                          (                                                                   w                                                                   )                                                                                                                                     1                                                                                               k                                                                         +                                                                         1                                                                                                                                          P                                                                      θ                                                                      h                                                                                          (                                                                   w                                                                   )                                                                   +                                                                                           k                                                                                               k                                                                         +                                                                         1                                                                                                                                          P                                                                      n                                                                                          (                                                                   w                                                                   )                                                                                                                              k                                                                                       k                                                                   +                                                                   1                                                                                                                                                                                                                                                                                                                                                                 =                                                                                                          k                                                                                           P                                                                      n                                                                                          (                                                                   w                                                                   )                                                                                                                                     P                                                                      θ                                                                      h                                                                                          (                                                                   w                                                                   )                                                                   +                                                                   k                                                                                           P                                                                      n                                                                                          (                                                                   w                                                                   )                                                                                                                                                                                                                                                                                     (4)                                                       \begin{equation}\begin{aligned} P^h(D=0|w,\theta) &= \frac{P^h(w,\theta|D=0)}{P^h(w,\theta)}P^h(D=0) \\ &=\frac{P_n(w)}{\frac{1}{k+1}P^h_{\theta}(w)+\frac{k}{k+1}P_n(w)}\frac{k}{k+1} \\ &=\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)} \end{aligned} \end{equation} \tag{4}                     Ph(D=0∣w,θ)​=Ph(w,θ)Ph(w,θ∣D=0)​Ph(D=0)=k+11​Pθh​(w)+k+1k​Pn​(w)Pn​(w)​k+1k​=Pθh​(w)+kPn​(w)kPn​(w)​​​(4)
后验概率                                             P                            h                                  (                         D                         ∣                                   w                            i                                  ,                         θ                         )                              P^h(D|w_i,\theta)                  Ph(D∣wi​,θ)的对数似然的期望                                   E                                   [                            l                            o                            g                            (                                       P                               h                                      (                            D                            ∣                                       w                               i                                      ,                            θ                            )                            )                            ]                                       E \left[log(P^h(D|w_i,\theta))\right]                  E[log(Ph(D∣wi​,θ))]为
                                                                                                                                                                                                                                                                      J                                                                h                                                                                  (                                                             θ                                                             )                                                                                                                                                                                             =                                                             E                                                                                   [                                                                l                                                                o                                                                g                                                                (                                                                                       P                                                                   h                                                                                      (                                                                D                                                                ∣                                                                w                                                                ,                                                                θ                                                                )                                                                )                                                                ]                                                                                                                                                                                                                                                                                                                                           =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                l                                                                o                                                                g                                                                                       P                                                                   h                                                                                      (                                                                D                                                                =                                                                1                                                                ∣                                                                w                                                                ,                                                                θ                                                                )                                                                ]                                                                                  +                                                                                   E                                                                                       P                                                                   n                                                                                                                              [                                                                l                                                                o                                                                g                                                                                       P                                                                   h                                                                                      (                                                                D                                                                =                                                                0                                                                ∣                                                                w                                                                ,                                                                θ                                                                )                                                                ]                                                                                                                                                                                                                                                                                                                                           =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                l                                                                o                                                                g                                                                                                                                        P                                                                         θ                                                                         h                                                                                              (                                                                      w                                                                      )                                                                                                                                           P                                                                         θ                                                                         h                                                                                              (                                                                      w                                                                      )                                                                      +                                                                      k                                                                                               P                                                                         n                                                                                              (                                                                      w                                                                      )                                                                                                             ]                                                                                  +                                                                                   E                                                                                       P                                                                   n                                                                                                                              [                                                                l                                                                o                                                                g                                                                                                               k                                                                                               P                                                                         n                                                                                              (                                                                      w                                                                      )                                                                                                                                           P                                                                         θ                                                                         h                                                                                              (                                                                      w                                                                      )                                                                      +                                                                      k                                                                                               P                                                                         n                                                                                              (                                                                      w                                                                      )                                                                                                             ]                                                                                                                                                                                                                                                               (5)                                                       \begin{equation}\begin{aligned} J^h(\theta)&=E \left[log(P^h(D|w,\theta))\right] \\ &= E_{P_d^h}\left[logP^h(D=1|w,\theta)\right] +E_{P_n}\left[logP^h(D=0|w,\theta)\right] \\ &= E_{P_d^h}\left[log\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\right] +E_{P_n}\left[log\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\right] \\ \end{aligned} \end{equation} \tag{5}                     Jh(θ)​=E[log(Ph(D∣w,θ))]=EPdh​​[logPh(D=1∣w,θ)]+EPn​​[logPh(D=0∣w,θ)]=EPdh​​[logPθh​(w)+kPn​(w)Pθh​(w)​]+EPn​​[logPθh​(w)+kPn​(w)kPn​(w)​]​​(5)
我们来算一下梯度,即是
                                                                                                                                                                                                                                                                      ∂                                                                                       ∂                                                                   θ                                                                                                                                                     J                                                                   h                                                                                      (                                                                θ                                                                )                                                                                                                                                                                                                  =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                                                               k                                                                                               P                                                                         n                                                                                              (                                                                      w                                                                      )                                                                                                                                           P                                                                         θ                                                                         h                                                                                              (                                                                      w                                                                      )                                                                      +                                                                      k                                                                                               P                                                                         n                                                                                              (                                                                      w                                                                      )                                                                                                                                    ∂                                                                                           ∂                                                                      θ                                                                                                             l                                                                o                                                                g                                                                                       P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                ]                                                                                  −                                                                                                                                                                                                                                                                                                                      k                                                                                   E                                                                                       P                                                                   n                                                                                                                              [                                                                                                                                        P                                                                         θ                                                                         h                                                                                              (                                                                      w                                                                      )                                                                                                                                           P                                                                         θ                                                                         h                                                                                              (                                                                      w                                                                      )                                                                      +                                                                      k                                                                                               P                                                                         n                                                                                              (                                                                      w                                                                      )                                                                                                                                    ∂                                                                                           ∂                                                                      θ                                                                                                             l                                                                o                                                                g                                                                                       P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                ]                                                                                                                                                                                                                                                               (6)                                                       \begin{equation} \begin{aligned} \frac{\partial}{\partial{\theta}}{J^h(\theta)}&= E_{P_d^h}\left[\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}logP_{\theta}^h(w)\right] -\\&kE_{P_n}\left[\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}logP_{\theta}^h(w)\right] \end{aligned} \end{equation} \tag{6}                     ∂θ∂​Jh(θ)​=EPdh​​[Pθh​(w)+kPn​(w)kPn​(w)​∂θ∂​logPθh​(w)]−kEPn​​[Pθh​(w)+kPn​(w)Pθh​(w)​∂θ∂​logPθh​(w)]​​(6)
对(6)式做化简,有
                                                                                                                                                                                                                                                                      ∂                                                                                       ∂                                                                   θ                                                                                                                                                     J                                                                   h                                                                                      (                                                                θ                                                                )                                                                                                                                                                                                                  =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                                                               k                                                                                               P                                                                         n                                                                                              (                                                                      w                                                                      )                                                                                                                                           P                                                                         θ                                                                         h                                                                                              (                                                                      w                                                                      )                                                                      +                                                                      k                                                                                               P                                                                         n                                                                                              (                                                                                               w                                                                         i                                                                                              )                                                                                                                                    ∂                                                                                           ∂                                                                      θ                                                                                                             l                                                                o                                                                g                                                                                       P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                ]                                                                                  −                                                                                                                                                                                                                                                                                                                      k                                                                                   E                                                                                       P                                                                   n                                                                                                                              [                                                                                                                                        P                                                                         θ                                                                         h                                                                                              (                                                                      w                                                                      )                                                                                                                                           P                                                                         θ                                                                         h                                                                                              (                                                                      w                                                                      )                                                                      +                                                                      k                                                                                               P                                                                         n                                                                                              (                                                                      w                                                                      )                                                                                                                                    ∂                                                                                           ∂                                                                      θ                                                                                                             l                                                                o                                                                g                                                                                       P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                ]                                                                                                                                                                                                                                                                                                                                           =                                                                                   ∑                                                                w                                                                                                        [                                                                                       P                                                                   d                                                                   h                                                                                      ⋅                                                                                                               k                                                                                               P                                                                         n                                                                                              (                                                                      w                                                                      )                                                                                                                                           P                                                                         θ                                                                         h                                                                                              (                                                                      w                                                                      )                                                                      +                                                                      k                                                                                               P                                                                         n                                                                                              (                                                                      w                                                                      )                                                                                                                                    ∂                                                                                           ∂                                                                      θ                                                                                                             l                                                                o                                                                g                                                                                       P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                −                                                                                                                                                                                                                                                                                                                                                                 k                                                                                       P                                                                   n                                                                                      ⋅                                                                                                                                        P                                                                         θ                                                                         h                                                                                              (                                                                      w                                                                      )                                                                                                                                           P                                                                         θ                                                                         h                                                                                              (                                                                      w                                                                      )                                                                      +                                                                      k                                                                                               P                                                                         n                                                                                              (                                                                      w                                                                      )                                                                                                                                    ∂                                                                                           ∂                                                                      θ                                                                                                             l                                                                o                                                                g                                                                                       P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                ]                                                                                                                                                                                                                                                                                                                                           =                                                                                   ∑                                                                w                                                                                                        [                                                                                                               k                                                                                               P                                                                         n                                                                                              (                                                                      w                                                                      )                                                                                                                                           P                                                                         θ                                                                         h                                                                                              (                                                                      w                                                                      )                                                                      +                                                                      k                                                                                               P                                                                         n                                                                                              (                                                                      w                                                                      )                                                                                                             ×                                                                                                                                                                                                                                                                                                                                                                 (                                                                                       P                                                                   d                                                                   h                                                                                      (                                                                w                                                                )                                                                −                                                                                       P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                )                                                                                       ∂                                                                                           ∂                                                                      θ                                                                                                             l                                                                o                                                                g                                                                                       P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                ]                                                                                                                                                                                                                                                               (7)                                                       \begin{equation} \begin{aligned} \frac{\partial}{\partial{\theta}}{J^h(\theta)}&= E_{P_d^h}\left[\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w_i)}\frac{\partial}{\partial\theta}logP_{\theta}^h(w)\right] -\\&kE_{P_n}\left[\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}logP_{\theta}^h(w)\right]\\ &=\sum_w\left[P_d^h\cdot\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}logP_{\theta}^h(w)-\right.\\ &\left. kP_{n}\cdot\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}logP_{\theta}^h(w) \right]\\ &=\sum_w\left[\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\times\right.\\ &\left. (P_d^h(w)-P_{\theta}^h(w))\frac{\partial}{\partial\theta}logP_{\theta}^h(w) \right] \end{aligned} \end{equation} \tag{7}                     ∂θ∂​Jh(θ)​=EPdh​​[Pθh​(w)+kPn​(wi​)kPn​(w)​∂θ∂​logPθh​(w)]−kEPn​​[Pθh​(w)+kPn​(w)Pθh​(w)​∂θ∂​logPθh​(w)]=w∑​[Pdh​⋅Pθh​(w)+kPn​(w)kPn​(w)​∂θ∂​logPθh​(w)−kPn​⋅Pθh​(w)+kPn​(w)Pθh​(w)​∂θ∂​logPθh​(w)]=w∑​[Pθh​(w)+kPn​(w)kPn​(w)​×(Pdh​(w)−Pθh​(w))∂θ∂​logPθh​(w)]​​(7)
当噪声数据量级巨大,                                   k                         →                         ∞                              k\to \infty                  k→∞ ,                                                        k                                           P                                  n                                          (                               w                               )                                                             P                                  θ                                  h                                          (                               w                               )                               +                               k                                           P                                  n                                          (                               w                               )                                            →                         1                              \frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\to1                  Pθh​(w)+kPn​(w)kPn​(w)​→1 ,有
                                                                                                                                                                                                                                                                      ∂                                                                                       ∂                                                                   θ                                                                                                                                                     J                                                                   h                                                                                      (                                                                θ                                                                )                                                                                                                                                                                                                  =                                                                                   ∑                                                                w                                                                                                        [                                                                                                               k                                                                                               P                                                                         n                                                                                              (                                                                      w                                                                      )                                                                                                                                           P                                                                         θ                                                                         h                                                                                              (                                                                      w                                                                      )                                                                      +                                                                      k                                                                                               P                                                                         n                                                                                              (                                                                      w                                                                      )                                                                                                             ×                                                                                                                                                                                                                                                                                                                                                                 (                                                                                       P                                                                   d                                                                   h                                                                                      (                                                                w                                                                )                                                                −                                                                                       P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                )                                                                                       ∂                                                                                           ∂                                                                      θ                                                                                                             l                                                                o                                                                g                                                                                       P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                ]                                                                                                                                                                                                                                                                                                                                           →                                                                                   ∑                                                                w                                                                                                        [                                                                (                                                                                       P                                                                   d                                                                   h                                                                                      (                                                                w                                                                )                                                                −                                                                                       P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                )                                                                                       ∂                                                                                           ∂                                                                      θ                                                                                                             l                                                                o                                                                g                                                                                       P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                ]                                                                                                                                                                                                                                                               (8)                                                       \begin{equation} \begin{aligned} \frac{\partial}{\partial{\theta}}{J^h(\theta)}&= \sum_w\left[\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\times\right.\\ &\left. (P_d^h(w)-P_{\theta}^h(w))\frac{\partial}{\partial\theta}logP_{\theta}^h(w) \right]\\ &\to \sum_w\left[(P_d^h(w)-P_{\theta}^h(w))\frac{\partial}{\partial\theta}logP_{\theta}^h(w) \right] \end{aligned} \end{equation} \tag{8}                     ∂θ∂​Jh(θ)​=w∑​[Pθh​(w)+kPn​(w)kPn​(w)​×(Pdh​(w)−Pθh​(w))∂θ∂​logPθh​(w)]→w∑​[(Pdh​(w)−Pθh​(w))∂θ∂​logPθh​(w)]​​(8)
原多分类任务

我们盘算下原多分类任务的对数似然期望和梯度,看看                                   k                         →                         ∞                              k\to \infty                  k→∞ 时的新二分类任务和原多分类任务有什么关系。原多分类任务的优化目标为
                                                                                                                                                                                                                                                                      J                                                                h                                                                                  (                                                             θ                                                             )                                                                                                                                                                                             =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                l                                                                o                                                                g                                                                (                                                                                       P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                ]                                                                                                                                                                                                                                                                                                                                           =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                l                                                                o                                                                g                                                                                       (                                                                                                                    e                                                                         x                                                                         p                                                                         (                                                                                                   s                                                                            θ                                                                                                  (                                                                         w                                                                         ,                                                                         h                                                                         )                                                                         )                                                                                                                                                 ∑                                                                            w                                                                                                                            e                                                                            x                                                                            p                                                                            (                                                                                                       s                                                                               θ                                                                                                      (                                                                            w                                                                            ,                                                                            h                                                                            )                                                                            )                                                                                                                                           )                                                                                      ]                                                                                                                                                                                                                                                                                                                                           =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                                       s                                                                   θ                                                                                      (                                                                w                                                                ,                                                                h                                                                )                                                                ]                                                                                  −                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                l                                                                o                                                                g                                                                                       (                                                                                           ∑                                                                      w                                                                                                                  e                                                                      x                                                                      p                                                                                               (                                                                                                   s                                                                            θ                                                                                                  (                                                                         w                                                                         ,                                                                         h                                                                         )                                                                         )                                                                                                                  )                                                                                      ]                                                                                                                                                                                                                                                                                                                                           =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                                       s                                                                   θ                                                                                      (                                                                w                                                                ,                                                                h                                                                )                                                                ]                                                                                  −                                                             l                                                             o                                                             g                                                                                   (                                                                                       ∑                                                                   w                                                                                                             e                                                                   x                                                                   p                                                                                           (                                                                                               s                                                                         θ                                                                                              (                                                                      w                                                                      ,                                                                      h                                                                      )                                                                      )                                                                                                             )                                                                                                                                                                                                                                                               (9)                                                       \begin{equation}\begin{aligned} J^h(\theta)&=E_{P_d^h} \left[log(P_{\theta}^h(w)\right] \\ &= E_{P_d^h} \left[log\left(\frac{exp(s_{\theta}(w,h))}{\sum_w{exp(s_{\theta}(w,h))}}\right)\right]\\ &=E_{P_d^h}\left[s_{\theta}(w,h)\right]-E_{P_d^h}\left[log\left(\sum_w{exp\left(s_{\theta}(w,h)\right)}\right)\right]\\ &=E_{P_d^h}\left[s_{\theta}(w,h)\right]-log\left(\sum_w{exp\left(s_{\theta}(w,h)\right)}\right) \end{aligned} \end{equation} \tag{9}                     Jh(θ)​=EPdh​​[log(Pθh​(w)]=EPdh​​[log(∑w​exp(sθ​(w,h))exp(sθ​(w,h))​)]=EPdh​​[sθ​(w,h)]−EPdh​​[log(w∑​exp(sθ​(w,h)))]=EPdh​​[sθ​(w,h)]−log(w∑​exp(sθ​(w,h)))​​(9)
等式最后一步成立的原因是                                   [                         l                         o                         g                                   (                                       ∑                               w                                                 e                               x                               p                                           (                                               s                                     θ                                              (                                  w                                  ,                                  h                                  )                                  )                                                 )                                  ]                              \left[log\left(\sum_w{exp\left(s_{\theta}(w,h)\right)}\right)\right]                  [log(∑w​exp(sθ​(w,h)))]仅和模型预测分布                                             P                            θ                            h                                       P_{\theta}^h                  Pθh​有关,和真实数据分布                                             P                            d                            h                                       P_d^h                  Pdh​无关。
对(9)式求梯度,有                                                                                                                                                                                                                                                                     ∂                                                                                       ∂                                                                   θ                                                                                                                              J                                                                h                                                                                  (                                                             θ                                                             )                                                                                                                                                                                             =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                                       ∂                                                                                           ∂                                                                      θ                                                                                                                                    s                                                                   θ                                                                                      (                                                                w                                                                ,                                                                h                                                                )                                                                ]                                                                                  −                                                                                   ∂                                                                                       ∂                                                                   θ                                                                                                        l                                                             o                                                             g                                                                                   (                                                                                       ∑                                                                   w                                                                                                             e                                                                   x                                                                   p                                                                                           (                                                                                               s                                                                         θ                                                                                              (                                                                      w                                                                      ,                                                                      h                                                                      )                                                                      )                                                                                                             )                                                                                                                                                                                                                                                                                                                                           =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                                       ∂                                                                                           ∂                                                                      θ                                                                                                                                    s                                                                   θ                                                                                      (                                                                w                                                                ,                                                                h                                                                )                                                                ]                                                                                  −                                                                                   1                                                                                                               ∑                                                                      w                                                                                                                  e                                                                      x                                                                      p                                                                                               (                                                                                                   s                                                                            θ                                                                                                  (                                                                         w                                                                         ,                                                                         h                                                                         )                                                                         )                                                                                                                                                                             ∂                                                                                       ∂                                                                   θ                                                                                                                              ∑                                                                w                                                                                                        e                                                                x                                                                p                                                                                       (                                                                                           s                                                                      θ                                                                                          (                                                                   w                                                                   ,                                                                   h                                                                   )                                                                   )                                                                                                                                                                                                                                                                                                                                                                 =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                                       ∂                                                                                           ∂                                                                      θ                                                                                                                                    s                                                                   θ                                                                                      (                                                                w                                                                ,                                                                h                                                                )                                                                ]                                                                                  −                                                                                   1                                                                                                               ∑                                                                      w                                                                                                                  e                                                                      x                                                                      p                                                                                               (                                                                                                   s                                                                            θ                                                                                                  (                                                                         w                                                                         ,                                                                         h                                                                         )                                                                         )                                                                                                                                                                             ∑                                                                w                                                                                                        (                                                                                       s                                                                   θ                                                                                      (                                                                w                                                                ,                                                                h                                                                )                                                                                       ∂                                                                                           ∂                                                                      θ                                                                                                                                    s                                                                   θ                                                                                      (                                                                w                                                                ,                                                                h                                                                )                                                                )                                                                                                                                                                                                                                                                                                                                           =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                                       ∂                                                                                           ∂                                                                      θ                                                                                                                                    s                                                                   θ                                                                                      (                                                                w                                                                ,                                                                h                                                                )                                                                ]                                                                                  −                                                                                   ∑                                                                w                                                                                                                                                       s                                                                      θ                                                                                          (                                                                   w                                                                   ,                                                                   h                                                                   )                                                                                                                                     ∑                                                                      w                                                                                                                  e                                                                      x                                                                      p                                                                                               (                                                                                                   s                                                                            θ                                                                                                  (                                                                         w                                                                         ,                                                                         h                                                                         )                                                                         )                                                                                                                                                                             ∂                                                                                       ∂                                                                   θ                                                                                                                              s                                                                θ                                                                                  (                                                             w                                                             ,                                                             h                                                             )                                                                                                                                                                                                                                                                                                                      =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                                       ∂                                                                                           ∂                                                                      θ                                                                                                                                    s                                                                   θ                                                                                      (                                                                w                                                                ,                                                                h                                                                )                                                                ]                                                                                  −                                                                                   ∑                                                                w                                                                                                        P                                                                θ                                                                h                                                                                  (                                                             w                                                             )                                                                                   ∂                                                                                       ∂                                                                   θ                                                                                                                              s                                                                θ                                                                                  (                                                             w                                                             ,                                                             h                                                             )                                                                                                                                                                                                                                                                                                                      =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                                       ∂                                                                                           ∂                                                                      θ                                                                                                                                    s                                                                   θ                                                                                      (                                                                w                                                                ,                                                                h                                                                )                                                                ]                                                                                  −                                                                                   ∑                                                                w                                                                                                        P                                                                θ                                                                h                                                                                  (                                                             w                                                             )                                                                                   ∂                                                                                       ∂                                                                   θ                                                                                                                              s                                                                θ                                                                                  (                                                             w                                                             ,                                                             h                                                             )                                                                                                                                                                                                                                                                                                                      =                                                                                   ∑                                                                w                                                                                                        P                                                                d                                                                h                                                                                                        ∂                                                                                       ∂                                                                   θ                                                                                                                              s                                                                θ                                                                                  (                                                             w                                                             ,                                                             h                                                             )                                                             −                                                                                   ∑                                                                w                                                                                                        P                                                                θ                                                                h                                                                                  (                                                             w                                                             )                                                                                   ∂                                                                                       ∂                                                                   θ                                                                                                                              s                                                                θ                                                                                  (                                                             w                                                             ,                                                             h                                                             )                                                                                                                                                                                                                                                                                                                      =                                                                                   ∑                                                                w                                                                                  (                                                                                   P                                                                d                                                                h                                                                                  (                                                             w                                                             )                                                             −                                                                                   P                                                                θ                                                                h                                                                                  (                                                             w                                                             )                                                             )                                                                                   ∂                                                                                       ∂                                                                   θ                                                                                                                              s                                                                θ                                                                                  (                                                             w                                                             ,                                                             h                                                             )                                                                                                                                                                                                                                          (10)                                                       \begin{equation}\begin{aligned} \frac{\partial}{\partial\theta}J^h(\theta)&=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\frac{\partial}{\partial\theta}log\left(\sum_w{exp\left(s_{\theta}(w,h)\right)}\right)\\ &=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\frac{1}{\sum_w{exp\left(s_{\theta}(w,h)\right)}}\frac{\partial}{\partial\theta}\sum_w{exp\left(s_{\theta}(w,h)\right)}\\ &=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\frac{1}{\sum_w{exp\left(s_{\theta}(w,h)\right)}}\sum_w\left(s_{\theta}(w,h)\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right)\\ &=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\sum_w\frac{s_{\theta}(w,h)}{\sum_w{exp\left(s_{\theta}(w,h)\right)}}\frac{\partial}{\partial\theta}s_{\theta}(w,h)\\ &=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\sum_wP_{\theta}^h(w)\frac{\partial}{\partial\theta}s_{\theta}(w,h)\\ &=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\sum_wP_{\theta}^h(w)\frac{\partial}{\partial\theta}s_{\theta}(w,h)\\ &=\sum_wP_d^h\frac{\partial}{\partial\theta}s_{\theta}(w,h)-\sum_wP_{\theta}^h(w)\frac{\partial}{\partial\theta}s_{\theta}(w,h)\\ &=\sum_w(P_d^h(w)-P_{\theta}^h(w))\frac{\partial}{\partial\theta}s_{\theta}(w,h)\\ \end{aligned} \end{equation} \tag{10}                     ∂θ∂​Jh(θ)​=EPdh​​[∂θ∂​sθ​(w,h)]−∂θ∂​log(w∑​exp(sθ​(w,h)))=EPdh​​[∂θ∂​sθ​(w,h)]−∑w​exp(sθ​(w,h))1​∂θ∂​w∑​exp(sθ​(w,h))=EPdh​​[∂θ∂​sθ​(w,h)]−∑w​exp(sθ​(w,h))1​w∑​(sθ​(w,h)∂θ∂​sθ​(w,h))=EPdh​​[∂θ∂​sθ​(w,h)]−w∑​∑w​exp(sθ​(w,h))sθ​(w,h)​∂θ∂​sθ​(w,h)=EPdh​​[∂θ∂​sθ​(w,h)]−w∑​Pθh​(w)∂θ∂​sθ​(w,h)=EPdh​​[∂θ∂​sθ​(w,h)]−w∑​Pθh​(w)∂θ∂​sθ​(w,h)=w∑​Pdh​∂θ∂​sθ​(w,h)−w∑​Pθh​(w)∂θ∂​sθ​(w,h)=w∑​(Pdh​(w)−Pθh​(w))∂θ∂​sθ​(w,h)​​(10)对比公式(8)和公式(10),很像,但不一样。公式(8)最后是                                             ∂                                       ∂                               θ                                            l                         o                         g                                   P                            θ                            h                                  (                         w                         )                              \frac{\partial}{\partial\theta}logP_{\theta}^h(w)                  ∂θ∂​logPθh​(w),公式(10)最后是                                             ∂                                       ∂                               θ                                                      s                            θ                                  (                         w                         ,                         h                         )                              \frac{\partial}{\partial\theta}s_{\theta}(w,h)                  ∂θ∂​sθ​(w,h),咋回事?
不一样就对了,在NCE中,我们可以将                                             ∑                            w                                            e                            x                            p                                       (                                           s                                  θ                                          (                               w                               ,                               h                               )                               )                                                 \sum_w{exp\left(s_{\theta}(w,h)\right)}                  ∑w​exp(sθ​(w,h))等价成1,那公式(8)和公式(10)就一样了。那为什么可以等价呢?论文的说辞是:                                                        模型参数较多,把正则项当做常数,公式中其他项,比如                                           s                                  θ                                          ,能学到正则项。                                                 \textcolor{red}{{模型参数较多,把正则项当做常数,公式中其他项,比如s_{\theta},能学到正则项。}}                  模型参数较多,把正则项当做常数,公式中其他项,比如sθ​,能学到正则项。(正则项可以理解为                                             ∑                            w                                            e                            x                            p                                       (                                           s                                  θ                                          (                               w                               ,                               h                               )                               )                                                 \sum_w{exp\left(s_{\theta}(w,h)\right)}                  ∑w​exp(sθ​(w,h))),那么                                             ∑                            w                                            e                            x                            p                                       (                                           s                                  θ                                          (                               w                               ,                               h                               )                               )                                                 \sum_w{exp\left(s_{\theta}(w,h)\right)}                  ∑w​exp(sθ​(w,h))是1也好,100也好,都不会对模型收敛有影响。简单起见,当做1就行。

这段说辞还是太抽象了,有没有形象一点的表明?
两个任务为什么可以等价

原多分类任务

                                                                                                                                                                                                                                                                      J                                                                h                                                                                  (                                                             θ                                                             )                                                                                                                                                                                             =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                l                                                                o                                                                g                                                                (                                                                                       P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                ]                                                                                                                                                                                                                                                                                                                                           =                                                                                   E                                                                                       P                                                                   d                                                                   h                                                                                                                              [                                                                l                                                                o                                                                g                                                                                       (                                                                                                                    e                                                                         x                                                                         p                                                                         (                                                                                                   s                                                                            θ                                                                                                  (                                                                         w                                                                         ,                                                                         h                                                                         )                                                                         )                                                                                                                                                 ∑                                                                            w                                                                                                                            e                                                                            x                                                                            p                                                                            (                                                                                                       s                                                                               θ                                                                                                      (                                                                            w                                                                            ,                                                                            h                                                                            )                                                                            )                                                                                                                                           )                                                                                      ]                                                                                                                                                                                                                                                               (11)                                                       \begin{equation}\begin{aligned} J^h(\theta)&=E_{P_d^h} \left[log(P_{\theta}^h(w)\right] \\ &= E_{P_d^h} \left[log\left(\frac{exp(s_{\theta}(w,h))}{\sum_w{exp(s_{\theta}(w,h))}}\right)\right] \end{aligned} \end{equation} \tag{11}                     Jh(θ)​=EPdh​​[log(Pθh​(w)]=EPdh​​[log(∑w​exp(sθ​(w,h))exp(sθ​(w,h))​)]​​(11)
该任务的对数似然期望见公式(11),                                   l                         o                         g                              log                  log函数曲线如下:

如果                                   l                         o                         g                         (                                   P                            θ                            h                                  (                         w                         )                         =                         e                         x                         p                         (                                   s                            θ                                  (                         w                         ,                         h                         )                         )                         ∈                         [                         0                         ,                         +                         ∞                         ]                              log(P_{\theta}^h(w)=exp(s_{\theta}(w,h))\in[0,+\infty]                  log(Pθh​(w)=exp(sθ​(w,h))∈[0,+∞],                                             J                            h                                  (                         θ                         )                         =                                   E                                       P                               d                               h                                                      [                            l                            o                            g                            (                                       P                               θ                               h                                      (                            w                            )                            ]                                       J^h(\theta)=E_{P_d^h} \left[log(P_{\theta}^h(w)\right]                  Jh(θ)=EPdh​​[log(Pθh​(w)]不存在极值,无法收敛。
如果对                                   l                         o                         g                         (                                   P                            θ                            h                                  (                         w                         )                         =                         e                         x                         p                         (                                   s                            θ                                  (                         w                         ,                         h                         )                         )                         ∈                         [                         0                         ,                         +                         ∞                         ]                              log(P_{\theta}^h(w)=exp(s_{\theta}(w,h))\in[0,+\infty]                  log(Pθh​(w)=exp(sθ​(w,h))∈[0,+∞]进行归一化,                                   l                         o                         g                         (                                   P                            θ                            h                                  (                         w                         )                         =                                   [                            l                            o                            g                                       (                                                        e                                     x                                     p                                     (                                                   s                                        θ                                                  (                                     w                                     ,                                     h                                     )                                     )                                                                         ∑                                        w                                                                e                                        x                                        p                                        (                                                       s                                           θ                                                      (                                        w                                        ,                                        h                                        )                                        )                                                                   )                                      ]                                  ∈                         (                         0                         ,                         1                         )                              log(P_{\theta}^h(w)=\left[log\left(\frac{exp(s_{\theta}(w,h))}{\sum_w{exp(s_{\theta}(w,h))}}\right)\right]\in(0,1)                  log(Pθh​(w)=[log(∑w​exp(sθ​(w,h))exp(sθ​(w,h))​)]∈(0,1),                                             J                            h                                  (                         θ                         )                         =                                   E                                       P                               d                               h                                                      [                            l                            o                            g                            (                                       P                               θ                               h                                      (                            w                            )                            ]                                       J^h(\theta)=E_{P_d^h} \left[log(P_{\theta}^h(w)\right]                  Jh(θ)=EPdh​​[log(Pθh​(w)]存在极值,具备收敛条件。
现二分类任务

从公式(5)可知,
                                                                                                                                                                                          J                                                       h                                                                      (                                                    θ                                                    )                                                                                                                                                               =                                                    E                                                                       [                                                       l                                                       o                                                       g                                                       (                                                                           P                                                          h                                                                          (                                                       D                                                       ∣                                                       w                                                       ,                                                       θ                                                       )                                                       )                                                       ]                                                                                                                                                                                                                                                                                     =                                                                       E                                                                           P                                                          d                                                          h                                                                                                            [                                                       l                                                       o                                                       g                                                                           P                                                          h                                                                          (                                                       D                                                       =                                                       1                                                       ∣                                                       w                                                       ,                                                       θ                                                       )                                                       ]                                                                      +                                                                       E                                                                           P                                                          n                                                                                                            [                                                       l                                                       o                                                       g                                                                           P                                                          h                                                                          (                                                       D                                                       =                                                       0                                                       ∣                                                       w                                                       ,                                                       θ                                                       )                                                       ]                                                                                                                                                                                                                                                                                     =                                                                       E                                                                           P                                                          d                                                          h                                                                                                            [                                                       l                                                       o                                                       g                                                                                                                      P                                                                θ                                                                h                                                                                  (                                                             w                                                             )                                                                                                                         P                                                                θ                                                                h                                                                                  (                                                             w                                                             )                                                             +                                                             k                                                                                   P                                                                n                                                                                  (                                                             w                                                             )                                                                                              ]                                                                      +                                                                       E                                                                           P                                                          n                                                                                                            [                                                       l                                                       o                                                       g                                                                                                k                                                                                   P                                                                n                                                                                  (                                                             w                                                             )                                                                                                                         P                                                                θ                                                                h                                                                                  (                                                             w                                                             )                                                             +                                                             k                                                                                   P                                                                n                                                                                  (                                                             w                                                             )                                                                                              ]                                                                                                                                                                                                                                                                                     =                                                                       E                                                                           P                                                          d                                                          h                                                                                                            [                                                       l                                                       o                                                       g                                                       (                                                       σ                                                       (                                                       Δ                                                       )                                                       )                                                       ]                                                                      +                                                    k                                                                       E                                                                           P                                                          n                                                                                                            [                                                       l                                                       o                                                       g                                                       (                                                       1                                                       −                                                       σ                                                       (                                                       Δ                                                       )                                                       )                                                       ]                                                                                                                                                                                    \begin{equation}\begin{aligned} J^h(\theta)&=E \left[log(P^h(D|w,\theta))\right] \\ &= E_{P_d^h}\left[logP^h(D=1|w,\theta)\right] +E_{P_n}\left[logP^h(D=0|w,\theta)\right] \\ &= E_{P_d^h}\left[log\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\right] +E_{P_n}\left[log\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\right] \\ &= E_{P_d^h}\left[log(\sigma({\Delta}))\right] +kE_{P_n}\left[log(1-\sigma({\Delta}))\right] \\ \end{aligned} \tag{12}\end{equation}                     Jh(θ)​=E[log(Ph(D∣w,θ))]=EPdh​​[logPh(D=1∣w,θ)]+EPn​​[logPh(D=0∣w,θ)]=EPdh​​[logPθh​(w)+kPn​(w)Pθh​(w)​]+EPn​​[logPθh​(w)+kPn​(w)kPn​(w)​]=EPdh​​[log(σ(Δ))]+kEPn​​[log(1−σ(Δ))]​​(12)​
,其中                                   Δ                         =                         l                         o                         g                                   P                            θ                            h                                  (                         w                         )                         −                         l                         o                         g                         k                                   P                            n                                  (                         w                         )                              \Delta=logP_{\theta}^h(w)-logkP_n(w)                  Δ=logPθh​(w)−logkPn​(w),将公式(5)推导成具备                                   σ                              \sigma                  σ的公式(12),原因在于求导方便,                                             ∂                                       ∂                               x                                            σ                         (                         x                         )                         =                         σ                         (                         x                         )                         (                         1                         −                         σ                         (                         x                         )                         )                              \frac{\partial}{\partial x}\sigma(x)=\sigma(x)(1-\sigma(x))                  ∂x∂​σ(x)=σ(x)(1−σ(x)),将公式(5)推导成公式(12)的过程是:
                                                                                                                                                                                                              P                                                          θ                                                          h                                                                          (                                                       w                                                       )                                                                                                             P                                                          θ                                                          h                                                                          (                                                       w                                                       )                                                       +                                                       k                                                                           P                                                          n                                                                          (                                                       w                                                       )                                                                                                                                                                                 =                                                                       1                                                                           1                                                          +                                                                                                     k                                                                                       P                                                                   n                                                                                      (                                                                w                                                                )                                                                                                                               P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                                                                                                                                                                                                                                                                                                 =                                                                       1                                                                           1                                                          +                                                          e                                                          x                                                          p                                                          (                                                          l                                                          o                                                          g                                                          (                                                                                                     k                                                                                       P                                                                   n                                                                                      (                                                                w                                                                )                                                                                                                               P                                                                   θ                                                                   h                                                                                      (                                                                w                                                                )                                                                                                   )                                                          )                                                                                                                                                                                                                                                                                                        =                                                                       1                                                                           1                                                          +                                                          e                                                          x                                                          p                                                          (                                                          l                                                          o                                                          g                                                          k                                                                               P                                                             n                                                                              (                                                          w                                                          )                                                          −                                                          l                                                          o                                                          g                                                                               P                                                             θ                                                             h                                                                              (                                                          w                                                          )                                                          )                                                                                                                                                                                                                                                                                                        =                                                                       1                                                                           1                                                          +                                                          e                                                          x                                                          p                                                          (                                                          −                                                          (                                                          l                                                          o                                                          g                                                                               P                                                             θ                                                             h                                                                              (                                                          w                                                          )                                                          −                                                          l                                                          o                                                          g                                                          k                                                                               P                                                             n                                                                              (                                                          w                                                          )                                                          )                                                          )                                                                                                                                                                                                                                                                                                        =                                                    σ                                                    (                                                    l                                                    o                                                    g                                                                       P                                                       θ                                                       h                                                                      (                                                    w                                                    )                                                    −                                                    l                                                    o                                                    g                                                    k                                                                       P                                                       n                                                                      (                                                    w                                                    )                                                    )                                                                                                                                                                  \begin{equation}\begin{aligned} \frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}&=\frac{1}{1+\frac{kP_n(w)}{P_{\theta}^h(w)}}\\ &=\frac{1}{1+exp(log(\frac{kP_n(w)}{P_{\theta}^h(w)}))}\\ &=\frac{1}{1+exp(logkP_n(w)-logP_{\theta}^h(w))}\\ &=\frac{1}{1+exp(-(logP_{\theta}^h(w)-logkP_n(w)))}\\ &=\sigma(logP_{\theta}^h(w)-logkP_n(w))\\ \end{aligned} \tag{12}\end{equation}                     Pθh​(w)+kPn​(w)Pθh​(w)​​=1+Pθh​(w)kPn​(w)​1​=1+exp(log(Pθh​(w)kPn​(w)​))1​=1+exp(logkPn​(w)−logPθh​(w))1​=1+exp(−(logPθh​(w)−logkPn​(w)))1​=σ(logPθh​(w)−logkPn​(w))​​(12)​
                                                                                                                                                                                          k                                                                           P                                                          n                                                                          (                                                       w                                                       )                                                                                                             P                                                          θ                                                          h                                                                          (                                                       w                                                       )                                                       +                                                       k                                                                           P                                                          n                                                                          (                                                       w                                                       )                                                                                                                                                                                 =                                                    1                                                    −                                                                                                                P                                                             θ                                                             h                                                                              (                                                          w                                                          )                                                                                                                   P                                                             θ                                                             h                                                                              (                                                          w                                                          )                                                          +                                                          k                                                                               P                                                             n                                                                              (                                                          w                                                          )                                                                                                                                                                                                                                                                                                        =                                                    1                                                    −                                                    σ                                                    (                                                    l                                                    o                                                    g                                                                       P                                                       θ                                                       h                                                                      (                                                    w                                                    )                                                    −                                                    l                                                    o                                                    g                                                    k                                                                       P                                                       n                                                                      (                                                    w                                                    )                                                    )                                                                                                                                                                  \begin{equation}\begin{aligned} \frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}&=1-\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\\ &=1-\sigma(logP_{\theta}^h(w)-logkP_n(w))\\ \end{aligned} \tag{13}\end{equation}                     Pθh​(w)+kPn​(w)kPn​(w)​​=1−Pθh​(w)+kPn​(w)Pθh​(w)​=1−σ(logPθh​(w)−logkPn​(w))​​(13)​
于是,盘算对数似然均值(公式(12))对                                   l                         o                         g                                   P                            θ                            h                                  (                         w                         )                              logP_{\theta}^h(w)                  logPθh​(w)的一阶导,有
                                                                                                                                                                                          ∂                                                                           J                                                          h                                                                          (                                                       θ                                                       )                                                                                         ∂                                                       l                                                       o                                                       g                                                                           P                                                          θ                                                          h                                                                          (                                                       w                                                       )                                                                                                                                                                                 =                                                                                           ∂                                                                               J                                                             h                                                                              (                                                          θ                                                          )                                                                                              ∂                                                          Δ                                                                                                                                ∂                                                          Δ                                                                                              ∂                                                          l                                                          o                                                          g                                                                               P                                                             θ                                                             h                                                                              (                                                          w                                                          )                                                                                                                                                                                                                                                                                                        =                                                                                           ∂                                                                               J                                                             h                                                                              (                                                          θ                                                          )                                                                                              ∂                                                          Δ                                                                                                                                                                                                                                                                                                        =                                                                       ∂                                                                           ∂                                                          Δ                                                                                                            {                                                                           E                                                                               P                                                             d                                                             h                                                                                                                  [                                                          l                                                          o                                                          g                                                          (                                                          σ                                                          (                                                          Δ                                                          )                                                          )                                                          ]                                                                          +                                                       k                                                                           E                                                                               P                                                             n                                                                                                                  [                                                          l                                                          o                                                          g                                                          (                                                          1                                                          −                                                          σ                                                          (                                                          Δ                                                          )                                                          )                                                          ]                                                                          }                                                                                                                                                                                                                                                                                     =                                                                       E                                                                           P                                                          d                                                          h                                                                                                            [                                                                           ∂                                                                               ∂                                                             Δ                                                                                              l                                                       o                                                       g                                                       (                                                       σ                                                       (                                                       Δ                                                       )                                                       )                                                       ]                                                                      +                                                    k                                                                       E                                                                           P                                                          n                                                                                                            [                                                                           ∂                                                                               ∂                                                             Δ                                                                                              l                                                       o                                                       g                                                       (                                                       1                                                       −                                                       σ                                                       (                                                       Δ                                                       )                                                       )                                                       ]                                                                                                                                                                                                                                                                                     =                                                                       E                                                                           P                                                          d                                                          h                                                                                                            [                                                       1                                                       −                                                       σ                                                       (                                                       Δ                                                       )                                                       ]                                                                      +                                                    k                                                                       E                                                                           P                                                          n                                                                                                            [                                                       −                                                       σ                                                       (                                                       Δ                                                       )                                                       ]                                                                                                                                                                                                                                                                                     =                                                                       ∑                                                       w                                                                                         P                                                       θ                                                       h                                                                      (                                                    w                                                    )                                                    (                                                    1                                                    −                                                    σ                                                    (                                                    Δ                                                    )                                                    )                                                    −                                                    k                                                                       P                                                       n                                                                      (                                                    w                                                    )                                                    σ                                                    (                                                    Δ                                                    )                                                                                                                                                                  \begin{equation}\begin{aligned} \frac{\partial J^h(\theta)}{\partial logP_{\theta}^h(w)} &=\frac{\partial J^h(\theta)}{\partial \Delta}\frac{\partial \Delta}{\partial logP_{\theta}^h(w)}\\ &=\frac{\partial J^h(\theta)}{\partial \Delta}\\ &=\frac{\partial }{\partial \Delta}\left\{E_{P_d^h}\left[log(\sigma({\Delta}))\right] +kE_{P_n}\left[log(1-\sigma({\Delta}))\right]\right\}\\ &=E_{P_d^h}\left[\frac{\partial }{\partial \Delta}log(\sigma({\Delta}))\right] +kE_{P_n}\left[\frac{\partial }{\partial \Delta}log(1-\sigma({\Delta}))\right]\\ &=E_{P_d^h}\left[1-\sigma({\Delta})\right] +kE_{P_n}\left[-\sigma({\Delta})\right]\\ &=\sum_wP_{\theta}^h(w)(1-\sigma({\Delta}))-kP_n(w)\sigma({\Delta})\\ \end{aligned} \tag{14}\end{equation}                     ∂logPθh​(w)∂Jh(θ)​​=∂Δ∂Jh(θ)​∂logPθh​(w)∂Δ​=∂Δ∂Jh(θ)​=∂Δ∂​{EPdh​​[log(σ(Δ))]+kEPn​​[log(1−σ(Δ))]}=EPdh​​[∂Δ∂​log(σ(Δ))]+kEPn​​[∂Δ∂​log(1−σ(Δ))]=EPdh​​[1−σ(Δ)]+kEPn​​[−σ(Δ)]=w∑​Pθh​(w)(1−σ(Δ))−kPn​(w)σ(Δ)​​(14)​
如果                                             P                            θ                            h                                  (                         w                         )                         =                                   P                            d                            h                                  (                         w                         )                              P_{\theta}^h(w)=P_d^h(w)                  Pθh​(w)=Pdh​(w),对数似然均值达到极大值(这个是废话,由于训练目标就是盼望                                             P                            θ                            h                                  (                         w                         )                         →                                   P                            d                            h                                  (                         w                         )                              P_{\theta}^h(w)\to P_d^h(w)                  Pθh​(w)→Pdh​(w),并且在优化策略章节开始部门,我们就让                                             P                            θ                            h                                  (                         w                         )                         =                                   P                            d                            h                                  (                         w                         )                              P_{\theta}^h(w)= P_d^h(w)                  Pθh​(w)=Pdh​(w))其中                                             P                            d                            h                                  (                         w                         )                              P_d^h(w)                  Pdh​(w)表示真实分布。
我们再盘算对数似然均值(公式(12))对                                   l                         o                         g                                   P                            θ                            h                                  (                         w                         )                              logP_{\theta}^h(w)                  logPθh​(w)的二阶导,有:
                                                                                                                                                                                                              ∂                                                          2                                                                                              J                                                          h                                                                          (                                                       θ                                                       )                                                                                         ∂                                                       l                                                       o                                                                           g                                                          2                                                                                              P                                                          θ                                                          h                                                                          (                                                       w                                                       )                                                                                                                                                                                 =                                                                                                                ∂                                                             2                                                                              J                                                          (                                                          θ                                                          )                                                                                              ∂                                                                               Δ                                                             2                                                                                                                                                                                                                                                                                                                            =                                                                       ∂                                                                           ∂                                                          Δ                                                                                                            {                                                                           E                                                                               P                                                             d                                                             h                                                                                                                  [                                                          1                                                          −                                                          σ                                                          (                                                          Δ                                                          )                                                          ]                                                                          +                                                       k                                                                           E                                                                               P                                                             n                                                                                                                  [                                                          −                                                          σ                                                          (                                                          Δ                                                          )                                                          ]                                                                          }                                                                                                                                                                                                                                                                                     =                                                                       E                                                                           P                                                          d                                                          h                                                                                                            ∂                                                                           ∂                                                          Δ                                                                                                            [                                                       1                                                       −                                                       σ                                                       (                                                       Δ                                                       )                                                       ]                                                                      +                                                    k                                                                       E                                                                           P                                                          n                                                                                                            ∂                                                                           ∂                                                          Δ                                                                                                            [                                                       −                                                       σ                                                       (                                                       Δ                                                       )                                                       ]                                                                                                                                                                                                                                                                                     =                                                                       E                                                                           P                                                          d                                                          h                                                                                         [                                                    −                                                    σ                                                    (                                                    Δ                                                    )                                                    (                                                    1                                                    −                                                    σ                                                    (                                                    Δ                                                    )                                                    )                                                    ]                                                    +                                                    k                                                                       E                                                                           P                                                          n                                                                                         [                                                    −                                                    σ                                                    (                                                    Δ                                                    )                                                    (                                                    1                                                    −                                                    σ                                                    (                                                    Δ                                                    )                                                    )                                                    ]                                                                                                                                                                  \begin{equation}\begin{aligned} \frac{\partial^2 J^h(\theta)}{\partial log^2P_{\theta}^h(w)} &=\frac{\partial^2J(\theta)}{\partial \Delta^2}\\ &=\frac{\partial}{\partial \Delta} \left\{E_{P_d^h}\left[1-\sigma({\Delta})\right] +kE_{P_n}\left[-\sigma({\Delta})\right] \right\} \\ &= E_{P_d^h}\frac{\partial}{\partial \Delta}\left[1- \sigma({\Delta})\right] +kE_{P_n}\frac{\partial}{\partial \Delta}\left[-\sigma({\Delta})\right] \\ &= E_{P_d^h}[-\sigma(\Delta)(1-\sigma(\Delta))] +kE_{P_n}[-\sigma(\Delta)(1-\sigma(\Delta))] \\ \end{aligned} \tag{14}\end{equation}                     ∂log2Pθh​(w)∂2Jh(θ)​​=∂Δ2∂2J(θ)​=∂Δ∂​{EPdh​​[1−σ(Δ)]+kEPn​​[−σ(Δ)]}=EPdh​​∂Δ∂​[1−σ(Δ)]+kEPn​​∂Δ∂​[−σ(Δ)]=EPdh​​[−σ(Δ)(1−σ(Δ))]+kEPn​​[−σ(Δ)(1−σ(Δ))]​​(14)​
由于                                   [                         −                         σ                         (                         Δ                         )                         (                         1                         −                         σ                         (                         Δ                         )                         )                         ]                              [-\sigma(\Delta)(1-\sigma(\Delta))]                  [−σ(Δ)(1−σ(Δ))]始终小于0,以是二阶导始终小于0,阐明新二分类任务的对数似然均值是关于                                   l                         o                         g                                   P                            θ                            h                                  (                         w                         )                              logP_{\theta}^h(w)                  logPθh​(w)的凸函数,有唯一极大值。以是极大值肯定是                                             P                            θ                            h                                  (                         w                         )                         =                                   P                            h                                  (                         w                         )                              P_{\theta}^h(w)=P^h(w)                  Pθh​(w)=Ph(w)。
最紧张的是,整个推导过程对是否必要归一化没有要求,既然没有要求,直接让                                                   ∑                               w                                                 e                               x                               p                                           (                                               s                                     θ                                              (                                  w                                  ,                                  h                                  )                                  )                                                 =                            1                                  \sum_w{exp\left(s_{\theta}(w,h)\right)}=1                     ∑w​exp(sθ​(w,h))=1
代码实现

从公式(12),我们可以知道:                                   Δ                         =                         l                         o                         g                                   P                            θ                            h                                  (                         w                         )                         −                         l                         o                         g                         k                                   P                            n                                  (                         w                         )                              \Delta=logP_{\theta}^h(w)-logkP_n(w)                  Δ=logPθh​(w)−logkPn​(w)
                                                                                                                                                                                          J                                                       h                                                                      (                                                    θ                                                    )                                                                                                                                                               =                                                    E                                                                       [                                                       l                                                       o                                                       g                                                       (                                                                           P                                                          h                                                                          (                                                       D                                                       ∣                                                       w                                                       ,                                                       θ                                                       )                                                       )                                                       ]                                                                                                                                                                                                                                                                                     =                                                                       E                                                                           P                                                          d                                                          h                                                                                                            [                                                       l                                                       o                                                       g                                                       σ                                                       (                                                       Δ                                                       )                                                       ]                                                                      +                                                    k                                                                       E                                                                           P                                                          n                                                                                                            [                                                       l                                                       o                                                       g                                                       (                                                       1                                                       −                                                       σ                                                       (                                                       Δ                                                       )                                                       )                                                       ]                                                                                                                                                                                                                                                                                     =                                                                       E                                                                           P                                                          d                                                          h                                                                                                            [                                                       l                                                       o                                                       g                                                       σ                                                       (                                                       l                                                       o                                                       g                                                                           P                                                          θ                                                          h                                                                          (                                                       w                                                       )                                                       −                                                       l                                                       o                                                       g                                                       k                                                                           P                                                          n                                                                          (                                                       w                                                       )                                                       )                                                       ]                                                                      +                                                                                                                                                                                                                                                                                                                                                                         k                                                                       E                                                                           P                                                          n                                                                                                            [                                                       l                                                       o                                                       g                                                       (                                                       1                                                       −                                                       σ                                                       (                                                       l                                                       o                                                       g                                                                           P                                                          θ                                                          h                                                                          (                                                       w                                                       )                                                       −                                                       l                                                       o                                                       g                                                       k                                                                           P                                                          n                                                                          (                                                       w                                                       )                                                       )                                                       )                                                       ]                                                                                                                                                                                                                                                                                     =                                                                       ∑                                                       w                                                                                         {                                                                           P                                                          d                                                          h                                                                                              [                                                          l                                                          o                                                          g                                                          σ                                                          (                                                          l                                                          o                                                          g                                                                               P                                                             θ                                                             h                                                                              (                                                          w                                                          )                                                          −                                                          l                                                          o                                                          g                                                          k                                                                               P                                                             n                                                                              (                                                          w                                                          )                                                          )                                                          ]                                                                          }                                                                      +                                                                                                                                                                                                                                                                                                                                                                         k                                                                       ∑                                                       w                                                                                         {                                                                           P                                                          n                                                                                              [                                                          l                                                          o                                                          g                                                          (                                                          1                                                          −                                                          σ                                                          (                                                          l                                                          o                                                          g                                                                               P                                                             θ                                                             h                                                                              (                                                          w                                                          )                                                          −                                                          l                                                          o                                                          g                                                          k                                                                               P                                                             n                                                                              (                                                          w                                                          )                                                          )                                                          )                                                          ]                                                                          }                                                                                                                                                                                                                                                                                     →                                                    l                                                    o                                                    g                                                    (                                                    σ                                                    (                                                    l                                                    o                                                    g                                                                       P                                                       θ                                                       h                                                                      (                                                                       w                                                       0                                                                      )                                                    −                                                    l                                                    o                                                    g                                                    k                                                                       P                                                       n                                                                      (                                                                       w                                                       0                                                                      )                                                    )                                                    +                                                                                                                                                                                                                                                                                                                                                                                            ∑                                                                           i                                                          =                                                          1                                                                          k                                                                                         [                                                       l                                                       o                                                       g                                                       (                                                       1                                                       −                                                       σ                                                       (                                                       l                                                       o                                                       g                                                                           P                                                          θ                                                          h                                                                          (                                                                           w                                                          i                                                                          )                                                       −                                                       l                                                       o                                                       g                                                       k                                                                           P                                                          n                                                                          (                                                                           w                                                          i                                                                          )                                                       )                                                       )                                                       ]                                                                                                                                                                                                                                                                                     =                                                    l                                                    o                                                    g                                                    (                                                    σ                                                    (                                                                       s                                                       θ                                                                      (                                                                       w                                                       0                                                                      ,                                                    h                                                    )                                                    −                                                    l                                                    o                                                    g                                                    k                                                                       P                                                       n                                                                      (                                                                       w                                                       0                                                                      )                                                    )                                                    +                                                                                                                                                                                                                                                                                                                                                                                            ∑                                                                           i                                                          =                                                          1                                                                          k                                                                                         [                                                       l                                                       o                                                       g                                                       (                                                       1                                                       −                                                       σ                                                       (                                                                           s                                                          θ                                                                          (                                                                           w                                                          i                                                                          ,                                                       h                                                       )                                                       −                                                       l                                                       o                                                       g                                                       k                                                                           P                                                          n                                                                          (                                                                           w                                                          i                                                                          )                                                       )                                                       )                                                       ]                                                                                                                                                                                    \begin{equation}\begin{aligned} J^h(\theta)&=E \left[log(P^h(D|w,\theta))\right] \\ &= E_{P_d^h}\left[log\sigma({\Delta})\right] +kE_{P_n}\left[log(1-\sigma({\Delta}))\right] \\ &= E_{P_d^h}\left[log\sigma(logP_{\theta}^h(w)-logkP_n(w))\right] +\\ &\quad\quad\quad\quad\quad\quad kE_{P_n}\left[log(1-\sigma(logP_{\theta}^h(w)-logkP_n(w)))\right] \\ &= \sum_w\left\{P_d^h\left[log\sigma(logP_{\theta}^h(w)-logkP_n(w))\right] \right\}+\\ &\quad\quad\quad\quad\quad\quad k\sum_w\left\{P_n\left[log(1-\sigma(logP_{\theta}^h(w)-logkP_n(w)))\right]\right\} \\ &\to log(\sigma(logP_{\theta}^h(w_0)-logkP_n(w_0)) +\\ &\quad\quad\quad\quad\quad\quad\sum_{i=1}^k\left[log(1-\sigma(logP_{\theta}^h(w_i)-logkP_n(w_i)))\right] \\ &=log(\sigma(s_{\theta}(w_0,h)-logkP_n(w_0)) +\\ &\quad\quad\quad\quad\quad\quad\sum_{i=1}^k\left[log(1-\sigma(s_{\theta}(w_i,h)-logkP_n(w_i)))\right] \\ \end{aligned} \tag{15}\end{equation}                     Jh(θ)​=E[log(Ph(D∣w,θ))]=EPdh​​[logσ(Δ)]+kEPn​​[log(1−σ(Δ))]=EPdh​​[logσ(logPθh​(w)−logkPn​(w))]+kEPn​​[log(1−σ(logPθh​(w)−logkPn​(w)))]=w∑​{Pdh​[logσ(logPθh​(w)−logkPn​(w))]}+kw∑​{Pn​[log(1−σ(logPθh​(w)−logkPn​(w)))]}→log(σ(logPθh​(w0​)−logkPn​(w0​))+i=1∑k​[log(1−σ(logPθh​(wi​)−logkPn​(wi​)))]=log(σ(sθ​(w0​,h)−logkPn​(w0​))+i=1∑k​[log(1−σ(sθ​(wi​,h)−logkPn​(wi​)))]​​(15)​
详细实现时,正样本项仅思量目标class,负样本项随机选择k个样本,通过蒙特卡洛来模拟抽样。
那终极损失函数代码应该怎么写呢?
                                                                                                                                                                       l                                                    o                                                    s                                                    s                                                                                                                                                               =                                                    −                                                                       J                                                       h                                                                      (                                                    θ                                                    )                                                                                                                                                                                                                                                                   =                                                    −                                                    l                                                    o                                                    g                                                    (                                                    σ                                                    (                                                                       s                                                       θ                                                                      (                                                                       w                                                       0                                                                      ,                                                    h                                                    )                                                    −                                                    l                                                    o                                                    g                                                    k                                                                       P                                                       n                                                                      (                                                                       w                                                       0                                                                      )                                                    )                                                    )                                                    −                                                                                                                                                                                                                                                                                                                                                                                            ∑                                                                           i                                                          =                                                          1                                                                          k                                                                                         [                                                       l                                                       o                                                       g                                                       (                                                       1                                                       −                                                       σ                                                       (                                                                           s                                                          θ                                                                          (                                                                           w                                                          i                                                                          ,                                                       h                                                       )                                                       −                                                       l                                                       o                                                       g                                                       k                                                                           P                                                          n                                                                          (                                                                           w                                                          i                                                                          )                                                       )                                                       )                                                       ]                                                                                                                                                                                    \begin{equation}\begin{aligned} loss &= -J^h(\theta) \\ &=-log(\sigma(s_{\theta}(w_0,h)-logkP_n(w_0))) - \\ &\quad\quad\quad\quad\quad\quad\sum_{i=1}^k\left[log(1-\sigma(s_{\theta}(w_i,h)-logkP_n(w_i)))\right] \\ \end{aligned} \tag{16}\end{equation}                     loss​=−Jh(θ)=−log(σ(sθ​(w0​,h)−logkPn​(w0​)))−i=1∑k​[log(1−σ(sθ​(wi​,h)−logkPn​(wi​)))]​​(16)​
公式(16)中有四个项输入,分别是


  •                                                    s                               θ                                      (                                       w                               0                                      ,                            h                            )                                  s_{\theta}(w_0,h)                     sθ​(w0​,h),目标class的logit
  •                                                    P                               n                                      (                                       w                               0                                      )                                  P_n(w_0)                     Pn​(w0​),目标class的噪声分布
  •                                                    s                               θ                                      (                                       w                               i                                      ,                            h                            )                                  s_{\theta}(w_i,h)                     sθ​(wi​,h),噪声class的logit
  •                                                    P                               n                                      (                                       w                               i                                      )                                  P_n(w_i)                     Pn​(wi​),噪声class的噪声分布
  1. from torch import randn, tensor, log, multinomial
  2. import torch.nn.functional as F
  3. from einops import repeat
  4. import torch
  5. import math
  6. bs,k=2,8
  7. num_classes=16
  8. #构造噪声:按照类别的频率采样
  9. #(噪声分布约等于实际数据分布,两个分布越接近,nce效果越好)
  10. classes=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
  11. class_freq=tensor([20,10,30,5,45,56,76,43,23,11,34,5,6,54,23,7])
  12. class_probs=class_freq/class_freq.sum()
  13. noise_classes=multinomial(class_probs, num_classes)
  14. #模型预测的logits
  15. logits=randn(bs, num_classes)
  16. #2个样本的标签
  17. labels=tensor([2, 4])
  18. #目标class的logit
  19. true_class_logits=logits.take_along_dim(labels[:, None], dim=1)
  20. #目标class的噪声分布
  21. true_class_noise=class_probs[labels]
  22. #噪声class的logit
  23. logits_k = repeat(logits, '(b 1) h -> (b k) h', k=k)
  24. noise_class_logits = logits_k.take_along_dim(noise_classes.reshape(bs * k, -1), dim=1)
  25. #噪声class的噪声分布
  26. noise_class_noise=class_probs[noise_classes]
  27. #nce loss计算
  28. true_class_loss = -torch.log( F.sigmoid(true_class_logits - torch.log(k*true_class_noise))).mean()
  29. noise_class_loss = -torch.log( 1-F.sigmoid(noise_class_logits - torch.log(k*noise_class_noise))).mean()
  30. loss = true_class_loss+noise_class_loss
  31. print("nce loss is {:.4f}".format(loss))
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

圆咕噜咕噜

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表