变分扩散模型 ELBO 重构推导详解

打印 上一主题 下一主题

主题 1044|帖子 1044|积分 3132

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
变分扩散模型 ELBO 重构推导详解

在变分扩散模型(Variational Diffusion Model)中,证据下界(Evidence Lower Bound, ELBO)的形式通过优化正向和逆向分布的匹配来实现数据生成。初始 ELBO (变分扩散模型中的 Evidence Lower Bound (ELBO) 详解)存在采样复杂性,尤其是过渡块中需要联合分布 (                                              q                            φ                                  (                                   x                                       t                               −                               1                                            ,                                   x                                       t                               +                               1                                            ∣                                   x                            0                                  )                              q_φ(x_{t-1}, x_{t+1}|x_0)                  qφ​(xt−1​,xt+1​∣x0​) ) 的样本,这引发了重新设计的动机。后面提出了一种等价的 ELBO 形式,通过贝叶斯定理和条件调解简化了计算。本文将详细推导这一重构过程,解释这种转变,面向具备概率论和深度学习基础的读者。
参考:https://arxiv.org/pdf/2403.18103

初始 ELBO 的题目

原始 ELBO

原来定义的 ELBO 为:
                                                    ELBO                                           φ                                  ,                                  θ                                                 (                            x                            )                            =                                       E                                                        q                                     φ                                              (                                               x                                     1                                              ∣                                               x                                     0                                              )                                                 [                            log                            ⁡                                       p                               θ                                      (                                       x                               0                                      ∣                                       x                               1                                      )                            ]                            −                                       E                                                        q                                     φ                                              (                                               x                                                   T                                        −                                        1                                                           ∣                                               x                                     0                                              )                                                            [                                           D                                               K                                     L                                                      (                                           q                                  φ                                          (                                           x                                  T                                          ∣                                           x                                               T                                     −                                     1                                                      )                               ∥                               p                               (                                           x                                  T                                          )                               )                               ]                                      −                                       ∑                                           t                                  =                                  1                                                      T                                  −                                  1                                                            E                                                        q                                     φ                                              (                                               x                                                   t                                        −                                        1                                                           ,                                               x                                                   t                                        +                                        1                                                           ∣                                               x                                     0                                              )                                                            [                                           D                                               K                                     L                                                      (                                           q                                  φ                                          (                                           x                                  t                                          ∣                                           x                                               t                                     −                                     1                                                      )                               ∥                                           p                                  θ                                          (                                           x                                  t                                          ∣                                           x                                               t                                     +                                     1                                                      )                               )                               ]                                            \text{ELBO}_{φ,θ}(x) = \mathbb{E}_{q_φ(x_1|x_0)} [\log p_θ(x_0|x_1)] - \mathbb{E}_{q_φ(x_{T-1}|x_0)} \left[ D_{KL}(q_φ(x_T|x_{T-1}) \| p(x_T)) \right] - \sum_{t=1}^{T-1} \mathbb{E}_{q_φ(x_{t-1}, x_{t+1}|x_0)} \left[ D_{KL}(q_φ(x_t|x_{t-1}) \| p_θ(x_t|x_{t+1})) \right]                     ELBOφ,θ​(x)=Eqφ​(x1​∣x0​)​[logpθ​(x0​∣x1​)]−Eqφ​(xT−1​∣x0​)​[DKL​(qφ​(xT​∣xT−1​)∥p(xT​))]−t=1∑T−1​Eqφ​(xt−1​,xt+1​∣x0​)​[DKL​(qφ​(xt​∣xt−1​)∥pθ​(xt​∣xt+1​))]


  • 初始块:重构项 (                                                    E                                                        q                                     φ                                              (                                               x                                     1                                              ∣                                               x                                     0                                              )                                                 [                            log                            ⁡                                       p                               θ                                      (                                       x                               0                                      ∣                                       x                               1                                      )                            ]                                  \mathbb{E}_{q_φ(x_1|x_0)} [\log p_θ(x_0|x_1)]                     Eqφ​(x1​∣x0​)​[logpθ​(x0​∣x1​)] )。
  • 最终块:先验匹配项 (                                         −                                       E                                                        q                                     φ                                              (                                               x                                                   T                                        −                                        1                                                           ∣                                               x                                     0                                              )                                                 [                                       D                                           K                                  L                                                 (                                       q                               φ                                      (                                       x                               T                                      ∣                                       x                                           T                                  −                                  1                                                 )                            ∥                            p                            (                                       x                               T                                      )                            )                            ]                                  -\mathbb{E}_{q_φ(x_{T-1}|x_0)} [D_{KL}(q_φ(x_T|x_{T-1}) \| p(x_T))]                     −Eqφ​(xT−1​∣x0​)​[DKL​(qφ​(xT​∣xT−1​)∥p(xT​))] )。
  • 过渡块:一致性项 (                                         −                                       ∑                                           t                                  =                                  1                                                      T                                  −                                  1                                                            E                                                        q                                     φ                                              (                                               x                                                   t                                        −                                        1                                                           ,                                               x                                                   t                                        +                                        1                                                           ∣                                               x                                     0                                              )                                                 [                                       D                                           K                                  L                                                 (                                       q                               φ                                      (                                       x                               t                                      ∣                                       x                                           t                                  −                                  1                                                 )                            ∥                                       p                               θ                                      (                                       x                               t                                      ∣                                       x                                           t                                  +                                  1                                                 )                            )                            ]                                  -\sum_{t=1}^{T-1} \mathbb{E}_{q_φ(x_{t-1}, x_{t+1}|x_0)} [D_{KL}(q_φ(x_t|x_{t-1}) \| p_θ(x_t|x_{t+1}))]                     −∑t=1T−1​Eqφ​(xt−1​,xt+1​∣x0​)​[DKL​(qφ​(xt​∣xt−1​)∥pθ​(xt​∣xt+1​))] )。
题目所在

过渡块需要从联合分布 (                                              q                            φ                                  (                                   x                                       t                               −                               1                                            ,                                   x                                       t                               +                               1                                            ∣                                   x                            0                                  )                              q_φ(x_{t-1}, x_{t+1}|x_0)                  qφ​(xt−1​,xt+1​∣x0​) ) 抽样,这涉及未来状态 (                                              x                                       t                               +                               1                                                 x_{t+1}                  xt+1​ ) 和过去状态 (                                              x                                       t                               −                               1                                                 x_{t-1}                  xt−1​ ) 的耦合。直接采样 (                                    (                                   x                                       t                               −                               1                                            ,                                   x                                       t                               +                               1                                            )                              (x_{t-1}, x_{t+1})                  (xt−1​,xt+1​) ) 复杂,因为 (                                              q                            φ                                  (                                   x                                       t                               +                               1                                            ∣                                   x                            0                                  )                              q_φ(x_{t+1}|x_0)                  qφ​(xt+1​∣x0​) ) 依赖多步正向过程,且正向 (                                              q                            φ                                  (                                   x                            t                                  ∣                                   x                                       t                               −                               1                                            )                              q_φ(x_t|x_{t-1})                  qφ​(xt​∣xt−1​) ) 和逆向 (                                              p                            θ                                  (                                   x                            t                                  ∣                                   x                                       t                               +                               1                                            )                              p_θ(x_t|x_{t+1})                  pθ​(xt​∣xt+1​) ) 方向相反,增加了计算负担。

重构动机与贝叶斯调解

一致性项的寻衅



  • (                                                    q                               φ                                      (                                       x                               t                                      ∣                                       x                                           t                                  −                                  1                                                 )                                  q_φ(x_t|x_{t-1})                     qφ​(xt​∣xt−1​) ) 是正向过渡,(                                                    p                               θ                                      (                                       x                               t                                      ∣                                       x                                           t                                  +                                  1                                                 )                                  p_θ(x_t|x_{t+1})                     pθ​(xt​∣xt+1​) ) 是逆向过渡,两者方向相反,导致需要同时处理 (                                                    x                                           t                                  −                                  1                                                       x_{t-1}                     xt−1​ ) 和 (                                                    x                                           t                                  +                                  1                                                       x_{t+1}                     xt+1​ ) 的样本。
  • 目标是简化一致性查抄,避免“反向”依赖。
贝叶斯定理的引入

通过贝叶斯定理调解条件分布:
                                         q                            (                                       x                               t                                      ∣                                       x                                           t                                  −                                  1                                                 )                            =                                                   q                                  (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              )                                  q                                  (                                               x                                     t                                              )                                                      q                                  (                                               x                                                   t                                        −                                        1                                                           )                                                       q(x_t|x_{t-1}) = \frac{q(x_{t-1}|x_t) q(x_t)}{q(x_{t-1})}                     q(xt​∣xt−1​)=q(xt−1​)q(xt−1​∣xt​)q(xt​)​
条件于 (                                              x                            0                                       x_0                  x0​ ):
                                         q                            (                                       x                               t                                      ∣                                       x                                           t                                  −                                  1                                                 ,                                       x                               0                                      )                            =                                                   q                                  (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              ,                                               x                                     0                                              )                                  q                                  (                                               x                                     t                                              ∣                                               x                                     0                                              )                                                      q                                  (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     0                                              )                                                       q(x_t|x_{t-1}, x_0) = \frac{q(x_{t-1}|x_t, x_0) q(x_t|x_0)}{q(x_{t-1}|x_0)}                     q(xt​∣xt−1​,x0​)=q(xt−1​∣x0​)q(xt−1​∣xt​,x0​)q(xt​∣x0​)​


  • 这一变更将正向 (                                         q                            (                                       x                               t                                      ∣                                       x                                           t                                  −                                  1                                                 ,                                       x                               0                                      )                                  q(x_t|x_{t-1}, x_0)                     q(xt​∣xt−1​,x0​) ) 转化为逆向形式的 (                                         q                            (                                       x                                           t                                  −                                  1                                                 ∣                                       x                               t                                      ,                                       x                               0                                      )                                  q(x_{t-1}|x_t, x_0)                     q(xt−1​∣xt​,x0​) ),方向与 (                                                    p                               θ                                      (                                       x                                           t                                  −                                  1                                                 ∣                                       x                               t                                      )                                  p_θ(x_{t-1}|x_t)                     pθ​(xt−1​∣xt​) ) 一致。
  • (                                                    x                               0                                            x_0                     x0​ ) 的条件确保分布依赖初始状态,避免无限制采样。

重构 ELBO 的推导

步骤 1:从 Jensen 不等式开始

从之前的基础推导(变分扩散模型 ELBO 的推导过程详解)出发:
                                         log                            ⁡                            p                            (                            x                            )                            ≥                                       E                                                        q                                     φ                                              (                                               x                                                   1                                        :                                        T                                                           ∣                                               x                                     0                                              )                                                            [                               log                               ⁡                                                        p                                     (                                                   x                                                       0                                           :                                           T                                                                )                                                                         q                                        φ                                                  (                                                   x                                                       1                                           :                                           T                                                                ∣                                                   x                                        0                                                  )                                                      ]                                            \log p(x) \geq \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \frac{p(x_{0:T})}{q_φ(x_{1:T}|x_0)} \right]                     logp(x)≥Eqφ​(x1:T​∣x0​)​[logqφ​(x1:T​∣x0​)p(x0:T​)​]
代入联合分布:
                                         p                            (                                       x                                           0                                  :                                  T                                                 )                            =                            p                            (                                       x                               T                                      )                            p                            (                                       x                               0                                      ∣                                       x                               1                                      )                                       ∏                                           t                                  =                                  2                                          T                                      p                            (                                       x                                           t                                  −                                  1                                                 ∣                                       x                               t                                      )                                  p(x_{0:T}) = p(x_T) p(x_0|x_1) \prod_{t=2}^T p(x_{t-1}|x_t)                     p(x0:T​)=p(xT​)p(x0​∣x1​)t=2∏T​p(xt−1​∣xt​)
                                                    q                               φ                                      (                                       x                                           1                                  :                                  T                                                 ∣                                       x                               0                                      )                            =                                       q                               φ                                      (                                       x                               1                                      ∣                                       x                               0                                      )                                       ∏                                           t                                  =                                  2                                          T                                                 q                               φ                                      (                                       x                               t                                      ∣                                       x                                           t                                  −                                  1                                                 ,                                       x                               0                                      )                                  q_φ(x_{1:T}|x_0) = q_φ(x_1|x_0) \prod_{t=2}^T q_φ(x_t|x_{t-1}, x_0)                     qφ​(x1:T​∣x0​)=qφ​(x1​∣x0​)t=2∏T​qφ​(xt​∣xt−1​,x0​)
(注意:这里 (                                              q                            φ                                  (                                   x                            t                                  ∣                                   x                                       t                               −                               1                                            ,                                   x                            0                                  )                              q_φ(x_t|x_{t-1}, x_0)                  qφ​(xt​∣xt−1​,x0​) ) 因马尔可夫性简化为 (                                              q                            φ                                  (                                   x                            t                                  ∣                                   x                                       t                               −                               1                                            )                              q_φ(x_t|x_{t-1})                  qφ​(xt​∣xt−1​)),但为一致性生存条件。)
步骤 2:睁开对数项

                                         log                            ⁡                                                   p                                  (                                               x                                                   0                                        :                                        T                                                           )                                                                   q                                     φ                                              (                                               x                                                   1                                        :                                        T                                                           ∣                                               x                                     0                                              )                                                 =                            log                            ⁡                                                   p                                  (                                               x                                     T                                              )                                  p                                  (                                               x                                     0                                              ∣                                               x                                     1                                              )                                               ∏                                                   t                                        =                                        2                                                  T                                              p                                  (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              )                                                                   q                                     φ                                              (                                               x                                     1                                              ∣                                               x                                     0                                              )                                               ∏                                                   t                                        =                                        2                                                  T                                                           q                                     φ                                              (                                               x                                     t                                              ∣                                               x                                                   t                                        −                                        1                                                           ,                                               x                                     0                                              )                                                       \log \frac{p(x_{0:T})}{q_φ(x_{1:T}|x_0)} = \log \frac{p(x_T) p(x_0|x_1) \prod_{t=2}^T p(x_{t-1}|x_t)}{q_φ(x_1|x_0) \prod_{t=2}^T q_φ(x_t|x_{t-1}, x_0)}                     logqφ​(x1:T​∣x0​)p(x0:T​)​=logqφ​(x1​∣x0​)∏t=2T​qφ​(xt​∣xt−1​,x0​)p(xT​)p(x0​∣x1​)∏t=2T​p(xt−1​∣xt​)​
分离:
                                         =                            log                            ⁡                                                   p                                  (                                               x                                     T                                              )                                  p                                  (                                               x                                     0                                              ∣                                               x                                     1                                              )                                                                   q                                     φ                                              (                                               x                                     1                                              ∣                                               x                                     0                                              )                                                 +                            log                            ⁡                                                                ∏                                                   t                                        =                                        2                                                  T                                              p                                  (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              )                                                                   ∏                                                   t                                        =                                        2                                                  T                                                           q                                     φ                                              (                                               x                                     t                                              ∣                                               x                                                   t                                        −                                        1                                                           ,                                               x                                     0                                              )                                                       = \log \frac{p(x_T) p(x_0|x_1)}{q_φ(x_1|x_0)} + \log \frac{\prod_{t=2}^T p(x_{t-1}|x_t)}{\prod_{t=2}^T q_φ(x_t|x_{t-1}, x_0)}                     =logqφ​(x1​∣x0​)p(xT​)p(x0​∣x1​)​+log∏t=2T​qφ​(xt​∣xt−1​,x0​)∏t=2T​p(xt−1​∣xt​)​
步骤 3:应用贝叶斯调解

对第二项,使用贝叶斯定理:
                                                                p                                  (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              )                                                                   q                                     φ                                              (                                               x                                     t                                              ∣                                               x                                                   t                                        −                                        1                                                           ,                                               x                                     0                                              )                                                 =                                                   p                                  (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              )                                                                   q                                     φ                                              (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              ,                                               x                                     0                                              )                                                                            q                                           φ                                                      (                                                       x                                           t                                                      ∣                                                       x                                           0                                                      )                                                                               q                                           φ                                                      (                                                       x                                                           t                                              −                                              1                                                                     ∣                                                       x                                           0                                                      )                                                                                \frac{p(x_{t-1}|x_t)}{q_φ(x_t|x_{t-1}, x_0)} = \frac{p(x_{t-1}|x_t)}{q_φ(x_{t-1}|x_t, x_0) \frac{q_φ(x_t|x_0)}{q_φ(x_{t-1}|x_0)}}                     qφ​(xt​∣xt−1​,x0​)p(xt−1​∣xt​)​=qφ​(xt−1​∣xt​,x0​)qφ​(xt−1​∣x0​)qφ​(xt​∣x0​)​p(xt−1​∣xt​)​
                                         =                                                                q                                     φ                                              (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              ,                                               x                                     0                                              )                                               q                                     φ                                              (                                               x                                     t                                              ∣                                               x                                     0                                              )                                                                   q                                     φ                                              (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     0                                              )                                                 ⋅                                                   p                                  (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              )                                                                   q                                     φ                                              (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              ,                                               x                                     0                                              )                                                       = \frac{q_φ(x_{t-1}|x_t, x_0) q_φ(x_t|x_0)}{q_φ(x_{t-1}|x_0)} \cdot \frac{p(x_{t-1}|x_t)}{q_φ(x_{t-1}|x_t, x_0)}                     =qφ​(xt−1​∣x0​)qφ​(xt−1​∣xt​,x0​)qφ​(xt​∣x0​)​⋅qφ​(xt−1​∣xt​,x0​)p(xt−1​∣xt​)​
整理乘积:
                                                    ∏                                           t                                  =                                  2                                          T                                                             p                                  (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              )                                                                   q                                     φ                                              (                                               x                                     t                                              ∣                                               x                                                   t                                        −                                        1                                                           ,                                               x                                     0                                              )                                                 =                                       ∏                                           t                                  =                                  2                                          T                                                             p                                  (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              )                                                                   q                                     φ                                              (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              ,                                               x                                     0                                              )                                                 ⋅                                                                q                                     φ                                              (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     0                                              )                                                                   q                                     φ                                              (                                               x                                     t                                              ∣                                               x                                     0                                              )                                                       \prod_{t=2}^T \frac{p(x_{t-1}|x_t)}{q_φ(x_t|x_{t-1}, x_0)} = \prod_{t=2}^T \frac{p(x_{t-1}|x_t)}{q_φ(x_{t-1}|x_t, x_0)} \cdot \frac{q_φ(x_{t-1}|x_0)}{q_φ(x_t|x_0)}                     t=2∏T​qφ​(xt​∣xt−1​,x0​)p(xt−1​∣xt​)​=t=2∏T​qφ​(xt−1​∣xt​,x0​)p(xt−1​∣xt​)​⋅qφ​(xt​∣x0​)qφ​(xt−1​∣x0​)​
步骤 4:期望分离

                                                    E                                                        q                                     φ                                              (                                               x                                                   1                                        :                                        T                                                           ∣                                               x                                     0                                              )                                                            [                               log                               ⁡                                                        p                                     (                                                   x                                        T                                                  )                                     p                                     (                                                   x                                        0                                                  ∣                                                   x                                        1                                                  )                                                                         q                                        φ                                                  (                                                   x                                        1                                                  ∣                                                   x                                        0                                                  )                                                      +                               log                               ⁡                                           ∏                                               t                                     =                                     2                                              T                                                                   p                                     (                                                   x                                                       t                                           −                                           1                                                                ∣                                                   x                                        t                                                  )                                                                         q                                        φ                                                  (                                                   x                                        t                                                  ∣                                                   x                                                       t                                           −                                           1                                                                ,                                                   x                                        0                                                  )                                                      ]                                            \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \frac{p(x_T) p(x_0|x_1)}{q_φ(x_1|x_0)} + \log \prod_{t=2}^T \frac{p(x_{t-1}|x_t)}{q_φ(x_t|x_{t-1}, x_0)} \right]                     Eqφ​(x1:T​∣x0​)​[logqφ​(x1​∣x0​)p(xT​)p(x0​∣x1​)​+logt=2∏T​qφ​(xt​∣xt−1​,x0​)p(xt−1​∣xt​)​]


  • 第一项
                                                    E                                                        q                                     φ                                              (                                               x                                                   1                                        :                                        T                                                           ∣                                               x                                     0                                              )                                                            [                               log                               ⁡                                                        p                                     (                                                   x                                        T                                                  )                                     p                                     (                                                   x                                        0                                                  ∣                                                   x                                        1                                                  )                                                                         q                                        φ                                                  (                                                   x                                        1                                                  ∣                                                   x                                        0                                                  )                                                      ]                                            \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \frac{p(x_T) p(x_0|x_1)}{q_φ(x_1|x_0)} \right]                     Eqφ​(x1:T​∣x0​)​[logqφ​(x1​∣x0​)p(xT​)p(x0​∣x1​)​]
                                         =                                       E                                                        q                                     φ                                              (                                               x                                     1                                              ∣                                               x                                     0                                              )                                                 [                            log                            ⁡                                       p                               θ                                      (                                       x                               0                                      ∣                                       x                               1                                      )                            ]                            +                                       E                                                        q                                     φ                                              (                                               x                                                   1                                        :                                        T                                                           ∣                                               x                                     0                                              )                                                            [                               log                               ⁡                                                        p                                     (                                                   x                                        T                                                  )                                                                         q                                        φ                                                  (                                                   x                                        T                                                  ∣                                                   x                                        0                                                  )                                                      ]                                            = \mathbb{E}_{q_φ(x_1|x_0)} [\log p_θ(x_0|x_1)] + \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \frac{p(x_T)}{q_φ(x_T|x_0)} \right]                     =Eqφ​(x1​∣x0​)​[logpθ​(x0​∣x1​)]+Eqφ​(x1:T​∣x0​)​[logqφ​(xT​∣x0​)p(xT​)​]


  • 第二项
                                                    E                                                        q                                     φ                                              (                                               x                                                   1                                        :                                        T                                                           ∣                                               x                                     0                                              )                                                            [                               log                               ⁡                                           ∏                                               t                                     =                                     2                                              T                                                                   p                                     (                                                   x                                                       t                                           −                                           1                                                                ∣                                                   x                                        t                                                  )                                                                         q                                        φ                                                  (                                                   x                                        t                                                  ∣                                                   x                                                       t                                           −                                           1                                                                ,                                                   x                                        0                                                  )                                                      ]                                      =                                       ∑                                           t                                  =                                  2                                          T                                                 E                                                        q                                     φ                                              (                                               x                                                   1                                        :                                        T                                                           ∣                                               x                                     0                                              )                                                            [                               log                               ⁡                                                        p                                     (                                                   x                                                       t                                           −                                           1                                                                ∣                                                   x                                        t                                                  )                                                                         q                                        φ                                                  (                                                   x                                        t                                                  ∣                                                   x                                                       t                                           −                                           1                                                                ,                                                   x                                        0                                                  )                                                      ]                                            \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \prod_{t=2}^T \frac{p(x_{t-1}|x_t)}{q_φ(x_t|x_{t-1}, x_0)} \right] = \sum_{t=2}^T \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \frac{p(x_{t-1}|x_t)}{q_φ(x_t|x_{t-1}, x_0)} \right]                     Eqφ​(x1:T​∣x0​)​[logt=2∏T​qφ​(xt​∣xt−1​,x0​)p(xt−1​∣xt​)​]=t=2∑T​Eqφ​(x1:T​∣x0​)​[logqφ​(xt​∣xt−1​,x0​)p(xt−1​∣xt​)​]
使用贝叶斯调解:
                                                                p                                  (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              )                                                                   q                                     φ                                              (                                               x                                     t                                              ∣                                               x                                                   t                                        −                                        1                                                           ,                                               x                                     0                                              )                                                 =                                                   p                                  (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              )                                               q                                     φ                                              (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     0                                              )                                                                   q                                     φ                                              (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              ,                                               x                                     0                                              )                                               q                                     φ                                              (                                               x                                     t                                              ∣                                               x                                     0                                              )                                                       \frac{p(x_{t-1}|x_t)}{q_φ(x_t|x_{t-1}, x_0)} = \frac{p(x_{t-1}|x_t) q_φ(x_{t-1}|x_0)}{q_φ(x_{t-1}|x_t, x_0) q_φ(x_t|x_0)}                     qφ​(xt​∣xt−1​,x0​)p(xt−1​∣xt​)​=qφ​(xt−1​∣xt​,x0​)qφ​(xt​∣x0​)p(xt−1​∣xt​)qφ​(xt−1​∣x0​)​
                                         log                            ⁡                                                   p                                  (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              )                                                                   q                                     φ                                              (                                               x                                     t                                              ∣                                               x                                                   t                                        −                                        1                                                           ,                                               x                                     0                                              )                                                 =                            log                            ⁡                                                   p                                  (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              )                                                                   q                                     φ                                              (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     t                                              ,                                               x                                     0                                              )                                                 +                            log                            ⁡                                                                q                                     φ                                              (                                               x                                                   t                                        −                                        1                                                           ∣                                               x                                     0                                              )                                                                   q                                     φ                                              (                                               x                                     t                                              ∣                                               x                                     0                                              )                                                       \log \frac{p(x_{t-1}|x_t)}{q_φ(x_t|x_{t-1}, x_0)} = \log \frac{p(x_{t-1}|x_t)}{q_φ(x_{t-1}|x_t, x_0)} + \log \frac{q_φ(x_{t-1}|x_0)}{q_φ(x_t|x_0)}                     logqφ​(xt​∣xt−1​,x0​)p(xt−1​∣xt​)​=logqφ​(xt−1​∣xt​,x0​)p(xt−1​∣xt​)​+logqφ​(xt​∣x0​)qφ​(xt−1​∣x0​)​
步骤 5:简化期望



  • 重构项
                                                    E                                                        q                                     φ                                              (                                               x                                     1                                              ∣                                               x                                     0                                              )                                                 [                            log                            ⁡                                       p                               θ                                      (                                       x                               0                                      ∣                                       x                               1                                      )                            ]                                  \mathbb{E}_{q_φ(x_1|x_0)} [\log p_θ(x_0|x_1)]                     Eqφ​(x1​∣x0​)​[logpθ​(x0​∣x1​)]


  • 先验匹配项
                                                    E                                                        q                                     φ                                              (                                               x                                                   1                                        :                                        T                                                           ∣                                               x                                     0                                              )                                                            [                               log                               ⁡                                                        p                                     (                                                   x                                        T                                                  )                                                                         q                                        φ                                                  (                                                   x                                        T                                                  ∣                                                   x                                        0                                                  )                                                      ]                                      =                            −                                       D                                           K                                  L                                                 (                                       q                               φ                                      (                                       x                               T                                      ∣                                       x                               0                                      )                            ∥                            p                            (                                       x                               T                                      )                            )                                  \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \frac{p(x_T)}{q_φ(x_T|x_0)} \right] = -D_{KL}(q_φ(x_T|x_0) \| p(x_T))                     Eqφ​(x1:T​∣x0​)​[logqφ​(xT​∣x0​)p(xT​)​]=−DKL​(qφ​(xT​∣x0​)∥p(xT​))


  • 一致性项
                                                    ∑                                           t                                  =                                  2                                          T                                                 E                                                        q                                     φ                                              (                                               x                                                   1                                        :                                        T                                                           ∣                                               x                                     0                                              )                                                            [                               log                               ⁡                                                        p                                     (                                                   x                                                       t                                           −                                           1                                                                ∣                                                   x                                        t                                                  )                                                                         q                                        φ                                                  (                                                   x                                                       t                                           −                                           1                                                                ∣                                                   x                                        t                                                  ,                                                   x                                        0                                                  )                                                      +                               log                               ⁡                                                                      q                                        φ                                                  (                                                   x                                                       t                                           −                                           1                                                                ∣                                                   x                                        0                                                  )                                                                         q                                        φ                                                  (                                                   x                                        t                                                  ∣                                                   x                                        0                                                  )                                                      ]                                            \sum_{t=2}^T \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \frac{p(x_{t-1}|x_t)}{q_φ(x_{t-1}|x_t, x_0)} + \log \frac{q_φ(x_{t-1}|x_0)}{q_φ(x_t|x_0)} \right]                     t=2∑T​Eqφ​(x1:T​∣x0​)​[logqφ​(xt−1​∣xt​,x0​)p(xt−1​∣xt​)​+logqφ​(xt​∣x0​)qφ​(xt−1​∣x0​)​]
第二项的和为:
                                         log                            ⁡                                                                q                                     φ                                              (                                               x                                     1                                              ∣                                               x                                     0                                              )                                                                   q                                     φ                                              (                                               x                                     T                                              ∣                                               x                                     0                                              )                                                 =                            log                            ⁡                                       q                               φ                                      (                                       x                               1                                      ∣                                       x                               0                                      )                            −                            log                            ⁡                                       q                               φ                                      (                                       x                               T                                      ∣                                       x                               0                                      )                                  \log \frac{q_φ(x_1|x_0)}{q_φ(x_T|x_0)} = \log q_φ(x_1|x_0) - \log q_φ(x_T|x_0)                     logqφ​(xT​∣x0​)qφ​(x1​∣x0​)​=logqφ​(x1​∣x0​)−logqφ​(xT​∣x0​)
但重点是第一项:
                                                    E                                                        q                                     φ                                              (                                               x                                                   t                                        −                                        1                                                           ,                                               x                                     t                                              ∣                                               x                                     0                                              )                                                            [                               log                               ⁡                                                        p                                     (                                                   x                                                       t                                           −                                           1                                                                ∣                                                   x                                        t                                                  )                                                                         q                                        φ                                                  (                                                   x                                                       t                                           −                                           1                                                                ∣                                                   x                                        t                                                  ,                                                   x                                        0                                                  )                                                      ]                                      =                            −                                       E                                                        q                                     φ                                              (                                               x                                     t                                              ∣                                               x                                     0                                              )                                                            [                                           D                                               K                                     L                                                      (                                           q                                  φ                                          (                                           x                                               t                                     −                                     1                                                      ∣                                           x                                  t                                          ,                                           x                                  0                                          )                               ∥                                           p                                  θ                                          (                                           x                                               t                                     −                                     1                                                      ∣                                           x                                  t                                          )                               )                               ]                                            \mathbb{E}_{q_φ(x_{t-1}, x_t|x_0)} \left[ \log \frac{p(x_{t-1}|x_t)}{q_φ(x_{t-1}|x_t, x_0)} \right] = -\mathbb{E}_{q_φ(x_t|x_0)} \left[ D_{KL}(q_φ(x_{t-1}|x_t, x_0) \| p_θ(x_{t-1}|x_t)) \right]                     Eqφ​(xt−1​,xt​∣x0​)​[logqφ​(xt−1​∣xt​,x0​)p(xt−1​∣xt​)​]=−Eqφ​(xt​∣x0​)​[DKL​(qφ​(xt−1​∣xt​,x0​)∥pθ​(xt−1​∣xt​))]
步骤 6:范围调解

从 (                                    t                         =                         2                              t=2                  t=2 ) 到 (                                    t                         =                         T                              t=T                  t=T ) 对应 (                                              x                                       t                               −                               1                                                 x_{t-1}                  xt−1​ ) 从 (                                              x                            1                                       x_1                  x1​ ) 到 (                                              x                                       T                               −                               1                                                 x_{T-1}                  xT−1​ ),与过渡块 (                                    t                         =                         1                              t=1                  t=1 ) 到 (                                    T                         −                         1                              T-1                  T−1 ) 一致,调解索引。
最终 ELBO

                                                    ELBO                                           φ                                  ,                                  θ                                                 (                            x                            )                            =                                       E                                                        q                                     φ                                              (                                               x                                     1                                              ∣                                               x                                     0                                              )                                                 [                            log                            ⁡                                       p                               θ                                      (                                       x                               0                                      ∣                                       x                               1                                      )                            ]                            −                                       D                                           K                                  L                                                 (                                       q                               φ                                      (                                       x                               T                                      ∣                                       x                               0                                      )                            ∥                            p                            (                                       x                               T                                      )                            )                            −                                       ∑                                           t                                  =                                  2                                          T                                                 E                                                        q                                     φ                                              (                                               x                                     t                                              ∣                                               x                                     0                                              )                                                            [                                           D                                               K                                     L                                                      (                                           q                                  φ                                          (                                           x                                               t                                     −                                     1                                                      ∣                                           x                                  t                                          ,                                           x                                  0                                          )                               ∥                                           p                                  θ                                          (                                           x                                               t                                     −                                     1                                                      ∣                                           x                                  t                                          )                               )                               ]                                            \text{ELBO}_{φ,θ}(x) = \mathbb{E}_{q_φ(x_1|x_0)} [\log p_θ(x_0|x_1)] - D_{KL}(q_φ(x_T|x_0) \| p(x_T)) - \sum_{t=2}^T \mathbb{E}_{q_φ(x_t|x_0)} \left[ D_{KL}(q_φ(x_{t-1}|x_t, x_0) \| p_θ(x_{t-1}|x_t)) \right]                     ELBOφ,θ​(x)=Eqφ​(x1​∣x0​)​[logpθ​(x0​∣x1​)]−DKL​(qφ​(xT​∣x0​)∥p(xT​))−t=2∑T​Eqφ​(xt​∣x0​)​[DKL​(qφ​(xt−1​∣xt​,x0​)∥pθ​(xt−1​∣xt​))]

推导总结



  • 贝叶斯定理将 (                                                    q                               φ                                      (                                       x                               t                                      ∣                                       x                                           t                                  −                                  1                                                 ,                                       x                               0                                      )                                  q_φ(x_t|x_{t-1}, x_0)                     qφ​(xt​∣xt−1​,x0​) ) 转化为 (                                                    q                               φ                                      (                                       x                                           t                                  −                                  1                                                 ∣                                       x                               t                                      ,                                       x                               0                                      )                                  q_φ(x_{t-1}|x_t, x_0)                     qφ​(xt−1​∣xt​,x0​) ),与 (                                                    p                               θ                                      (                                       x                                           t                                  −                                  1                                                 ∣                                       x                               t                                      )                                  p_θ(x_{t-1}|x_t)                     pθ​(xt−1​∣xt​) ) 方向一致。
  • 期望从联合分布简化为单变量,消除了 (                                                    x                                           t                                  +                                  1                                                       x_{t+1}                     xt+1​ ) 的依赖。
  • 新的 ELBO 保持优化目标,简化了采样复杂性。

代码实现片断(伪代码)

  1. def elbo_loss_new(x0, model, T, alpha_schedule):
  2.     elbo = 0.0
  3.     x1 = forward_transition(x0, alpha_schedule[1])
  4.     elbo += torch.mean(model.log_prob_x0_given_x1(x0, x1))  # Reconstruction
  5.    
  6.     xT = forward_multi_step(x0, alpha_schedule)
  7.     kl_prior = kl_divergence(xT, torch.zeros_like(xT), torch.ones_like(xT))
  8.     elbo -= kl_prior  # Prior matching
  9.    
  10.     for t in range(2, T + 1):
  11.         xt = forward_step(x0, t, alpha_schedule)
  12.         xt_minus_1 = forward_step(x0, t - 1, alpha_schedule)
  13.         kl_cons = kl_divergence(xt_minus_1, model.reverse_mean(xt, t), model.reverse_cov(xt, t))
  14.         elbo -= torch.mean(kl_cons)  # Consistency
  15.    
  16.     return elbo
复制代码

总结

重构后的 ELBO 通过贝叶斯调解消除了联合采样的复杂性,保持了模型的优化能力。这一设计体现了扩散模型的灵活性,为高效训练提供了大概。
渴望这篇推导帮助你明白!
跋文

2025年3月5日18点17分于上海,在grok 3大模型辅助下完成。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

杀鸡焉用牛刀

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表