扩散模型(DDPM)推导与整理

目录

前言

先说一些废话。其实 2022 年 12 月的时候我们所谓的“圈内人”已经用上了 ChatGpt,而从那时候开始我们都已经觉得这东西已经有无限的潜力,直到 2023 年 2 月,用一句流行话语来说就是:“ChatGpt 成功出圈了”,真正地火爆全网,所有行业都开始关注这个可能在未来取代某些职业的技术。

上面的一些废话主要也是想引出 DALL-E 2、Stable Diffusion 和 Midjourney 这一系列 Text-to-Image 工具的大火,我想说的是如果一项工具/技术只是在它所属的领域广为人知那么可能只是该项技术确实引起了业内人士的广泛关注,但是如果该工具/技术在全社会(各行各业)中广为人知,那么证明它将引起社会的某些改变,且不说这些改变是好是坏,但能否得到关注、得到”流量“已经是当今社会一项技术证明其自身价值的最好方法。

今天的主角是 DALL-E 2、Stable Diffusion 和 Midjourney 等工具背后的一个模型(DDPM: Denoising Diffusion Probabilistic Models,去噪扩散概率模型),当然这些工具并不是直接使用这个模型就能达到这么好的效果,或多或少对 DDPM 进行了一些改进,在下面讲到 DDPM 的一些弊端的时候,我们也会讲到。

DDPM 属于生成类模型,当前主流的有四大生成模型:GAN(生成对抗模型)、VAE(变微分自动编码器)、Flow-based(流模型)以及今天要说的 Diffusion Model(扩散模型)。而现在扩散模型是当前深度生成模型中的新 SOTA,不仅在图片生成任务中超越了原 SOTA:GAN,并且在多个领域都表现的十分出色。

Diffusion Model 的概念有点大,本笔记主要讲的是 DDPM(Denoising Diffusion Probabilistic Models),去噪扩散概率模型。该模型在《Denoising Diffusion Probabilistic Models》中被提出,而这篇 2020 年的论文中已经将 Diffusion Model 完成了应用任务。但是 Diffusion Model 首次出现实际上在 2015 年的《Deep Unsupervised Learning using Nonequilibrium Thermodynamics》 ,“使用非平衡热力学进行深度无监督学习”这篇论文中,而这篇论文的作者就已经提出了 Diffusion Model 的整体框架并且完成了数学公式的推导,后面陆续有各类 Diffusion 模型提出。但是直到 2020 年 DDPM 的提出才有真正应用。

《Denoising Diffusion Probabilistic Models》 主要说的是可以使用统计力学的摄动方法,从根本上重新创建一个从随机噪音开始的特定分布。也有一个反向扩散的方法可以将噪音回收成原始分布。

下面有关于 DDPM(Denoising Diffusion Probabilistic Models 下文简称 ) 的笔记涉及到的公式比较多,如有不适可以马上关闭网页,毕竟我也是不适了很久。这里感谢 Probabilistic Diffusion Model 概率扩散模型理论与完整PyTorch代码详细解读 对这篇文章的详细解读,给予了我极大的帮助。

背景知识

我必须承认,DDPM 的整体思路还是比较简单,但是具体到整体公式的推导,比如说扩散过程加噪的简便方式数学推导和逆扩散过程的损失函数数学推导还是涉及到许多数学知识,所以有必要在这里先把一些涉及到的数学知识再回顾一下,当然也是本人数学基础不太牢固,如果看了一眼大概能知道的,就可以略过本节。

条件概率的一般形式

  • $P(A, B, C)=P(C \mid B, A) P(B, A)=P(C \mid B, A) P(B \mid A) P(A)$

    说明:A、B、C 的联合概率分布可展开为上面的形式,具体其实就是乘法公式不断展开。这里联合概率写法可能各有不同,但参考联合概率的定义:联合概率表示两个事件共同发生的概率,A 与 B 的联合概率表示为 $P(A \cap B) $ 或者 $P(A, B)$

  • $P(B, C \mid A)=P(B \mid A) P(C \mid A, B)$

    说明:A 发生情况下,B, C 发生的概率。具体其实利用条件概率的公式进行展开,然后利用上面的公式进行化简,具体为:

image.png

基于马尔科夫假设的条件概率

马尔科夫假设,简单来说就是:当前时刻的概率分布只与上一时刻有关,与过去更远的时刻的概率分布无关。那么如果满足马尔科夫链关系:A→B→C,上面所提到的两个公式就可以简化为:

  • $P(A, B, C)=P(C \mid B, A) P(B, A)=P(C \mid B) P(B \mid A) P(A)$ ,因为 B 是 C 的上一时刻,而 A 已经与 C 无关,所以可以简化为此。
  • $P(B, C \mid A)=P(B \mid A) P(C \mid B)$,同理。

高斯分布的 KL 散度公式

对于两个单一变量的高斯分布 p 和 q 而言,它们的 KL 散度为: $K L(p, q)=\log \frac{\sigma_{2}}{\sigma_{1}}+\frac{\sigma_{1}^{2}+\left(\mu_{1}-\mu_{2}\right)^{2}}{2 \sigma_{2}^{2}}-\frac{1}{2}$

参数重整化(重参数化技巧)

来源于 VAE 中的一个技巧。

若希望从高斯分布 $N(\mu, \sigma^{2})$ 中采样,可以先从标准分布 $N(0, 1)$ 采样出 $z$,再得到 $\sigma * z+\mu$ 。这样做的好处是将随机性转移到了 $z$ 这个常量上,而 $\sigma$$\mu$ 则当做仿射变换网络的一部分。

单层 VAE 的原理公式与置信下界

多层 VAE 的原理公式与置信下界

多层 VAE 训练中,先是有 $x$ 推理出 $z$,然后再由 $z$ 去预测 $x$,这个过程其实和 Duffusion Model 是类似的,那么有理由相信 Duffusion Model 的目标函数和多层 VAE 的目标函数是类似的。

Diffusion Model

图片来自 DDPM 论文

Duffusion Model 分为两个过程:

  1. 扩散过程(正向过程):从目标分布($X_{0}$)到噪声分布($X_{T}$),这是一个熵增的过程(即从有序变为无序),即从原始分布中不断地去加高斯噪声直至最后变成一个各项对立的高斯分布;直白点说就是往输入数据(数据集)中不断加噪声(此噪声服从高斯分布),而这个加噪声的过程其实就是构建标签的过程
  2. 逆扩散过程(反向过程):基于一个噪声分布能够把目标分布推出来,然后从这个目标分布中去采样新的样本,从而生成新的图像。直白点来说去噪、不断复原的过程,那其实也是生成目标的过程。

Duffision Model 要干的事情是:现在有一堆的目标分布(比如上图中的人像照片),我们希望能够把逆扩散过程(比如 $X_{T}$$X_{0}$)的原理或者公式给预测出来,然后我们就可以随机地生成一个噪声分布从而去预测新的这个人像照片。

其中上图中,$q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)$ 代表扩散过程中的条件概率分布;而 $p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)$ 代表逆扩散过程中的条件概率分布。

扩散过程

所谓前向过程,也就是扩散过程,用具体的例子来讲就是往图片不断加上噪声的过程,这样过程的目的是为了构建训练样本。

image.png

迭代计算目标分布

我们给定一个初始数据分布(数据集、真实图片)$x_{0} \sim q(x)$,在这样一个前向加噪过程中,逐步往初始数据分布中添加高斯噪声,累计加噪 $T$ 次,得到一系列带噪声图片 $x_{1}, x_{2}, ... x_{T}$ ,即下图中的 $q$ 过程。加噪过程中,该高斯噪声的标准差是以固定值 $\beta_{t}$ 确定的,均值是以固定值 $\beta_{t}$ 和当前 $t$ 时刻的数据 $x_{t}$ 决定的,也是确定的,这里说的确定并非是一成不变的数值(会随着时间的变化而不同),而是指并非是从网络中预测出来的。并且由于前向过程每个 $t$ 时刻只与 $t-1$ 时刻有关,所以这个过程可以看做是一个马尔科夫链过程。其中加噪过程数学公式如下:

  • $q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)=\mathcal{N}\left(\mathbf{x}_{t} ; \sqrt{1-\beta_{t}} \mathbf{x}_{t-1}, \beta_{t} \mathbf{I}\right)$ :这个公式指在 $q$ 过程(前向过程)中,比如说从 $x_{1}$ 去预测 $x_{2}$ ,则服从一个条件概率分布,并且该条件概率分布是高斯分布,均值是 $\sqrt{1-\beta_{t}} \mathbf{x}_{t-1}$,方差是 $\beta_{t}$,这也说明了每次加噪的高斯分布只由 $\beta_{t}$$x_{t-1}$ 确定,是一个固定值而不是一个学习过程。那么就可以总结为:我有了 $x_{0}$,并且确定每一步的固定值 $\beta_{1}, \beta_{2}, \ldots, \beta_{t}$,就可以一步一步地迭代地去推出在 $T$ 时刻内的 $x_{1}, x_{2}, \ldots, x_{T}$,注意我这里说的是一步一步地迭代才能得到。
  • $q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right)=\prod_{t=1}^{T} q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)$:这个公式指给定 $x_{0}$$x_{1}$$x_{T}$ 的联合概率分布就是上面公式每一步迭代的结果相乘。
  • 论文中的 $\beta_{t}$ 的范围在 $(0,1)$ 之间的小数,并且随着 $t$ 的不断增大, $\beta_{t}$是逐渐增大的,这一点类比学习率的性质(只不过学习率是越来越小)

利用 $x_{t}$ 服从 $q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)=\mathcal{N}\left(\mathbf{x}_{t} ; \sqrt{1-\beta_{t}} \mathbf{x}_{t-1}, \beta_{t} \mathbf{I}\right)$ 的分布,在通过重参数化技巧,就可以算出 $x_{t}$,具体为:从一个正态分布中生成一个 $z$,那么 $x_{t} = z \times \sqrt{\beta_{t}} + \sqrt{1-\beta_{t}} \mathbf{x}_{t-1}$,经过这样的计算总能一步步迭代得到 $x_{T}$ 的采样值。

不迭代计算目标分布

刚才一再强调,只能使用迭代一步步算出 $x_{T}$ ,那有没有办法不迭代,不一步步地也能算出 $x_{T}$,得到最后 $T$ 时刻的数据分布了呢?是可以的😎,也就是说任意时刻的 $x_{t}$ 是可以直接从 $x_{0}$ 和固定值 $\left\{\beta_{T} \in(0,1)\right\}_{t=1}^{T}$ 计算得到,具体推导如下:

我们为了方便,定义如下:$\alpha_{t}=1-\beta_{t}$$\bar{\alpha}_{t}=\prod_{i=1}^{T} \alpha_{i}$ ,然后根据重参数技巧,从正态分布中生成 $z_{t-1},z_{t-2}, \ldots,$,那么有:

$\begin{aligned} \mathbf{x}_{t} & =\sqrt{\alpha_{t}} \mathbf{x}_{t-1}+\sqrt{1-\alpha_{t}} \mathbf{z}_{t-1} \quad ; \text { where } \mathbf{z}_{t-1}, \mathbf{z}_{t-2}, \cdots \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\ & =\sqrt{\alpha_{t}}\left(\sqrt{\alpha_{t-1}} \mathbf{x}_{t-2}+\sqrt{1-\alpha_{t-1}} \mathbf{z}_{t-2}\right)+\sqrt{1-\alpha_{t}} \mathbf{z}_{t-1} \quad ; \text { 将 } \mathbf{x}_{t-1} \text { 按上述式子进行代换 } \\ & =\sqrt{\alpha_{t} \alpha_{t-1}} \mathbf{x}_{t-2}+\sqrt{\alpha_{t}-\alpha_{t} \alpha_{t-1}} \mathbf{z}_{t-2}+\sqrt{1-\alpha_{t}} \mathbf{z}_{t-1} \\ & =\sqrt{\alpha_{t} \alpha_{t-1}} \mathbf{x}_{t-2}+\sqrt{1-\alpha_{t} \alpha_{t-1}} \overline{\mathbf{z}}_{t-2} \quad ; \text { where } \overline{\mathbf{z}}_{t-2} \text { merges two Gaussians }{ }^{*} . \\ & =\sqrt{\alpha_{t} \alpha_{t-1}} \left(\sqrt{\alpha_{t-2}} \mathbf{x}_{t-3}+\sqrt{1-\alpha_{t-2}} \mathbf{z}_{t-3}\right)+\sqrt{1-\alpha_{t} \alpha_{t-1}} \overline{\mathbf{z}}_{t-2} \quad ; \text { 将 } \mathbf{x}_{t-2} \text { 按上述式子进行代换 } \\ & = \sqrt{\alpha_{t} \alpha_{t-1} \alpha_{t-2}} \mathbf{x}_{t-3}+\sqrt{\alpha_{t}\alpha_{t-1}-\alpha_{t}\alpha_{t-1}\alpha_{t-2} } \mathbf{z}_{t-3}+\sqrt{1-\alpha_{t} \alpha_{t-1}} \overline{\mathbf{z}}_{t-2} \\ & = \sqrt{\alpha_{t} \alpha_{t-1} \alpha_{t-2}} \mathbf{x}_{t-3}+\sqrt{1-\alpha_{t} \alpha_{t-1}\alpha_{t-2}} \overline{\mathbf{z}}_{t-3} \quad ; \text { where } \overline{\mathbf{z}}_{t-2} \text { merges two Gaussians }{ }^{*} . \\ & =\cdots \\ & =\sqrt{\alpha_{t} \alpha_{t-1} \alpha_{t-2} \cdots \alpha_{1}} \mathbf{x}_{0}+\sqrt{1-\alpha_{t} \alpha_{t-1} \alpha_{t-2} \cdots \alpha_{1}} \mathbf{z} \quad ;\text { where } \mathbf{z} \text { is also } \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\ & =\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \mathbf{z} \\ q\left(\mathbf{x}_{t} \mid \mathbf{x}_{0}\right) & =\mathcal{N}\left(\mathbf{x}_{t} ; \sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0},\left(1-\bar{\alpha}_{t}\right) \mathbf{I}\right) \end{aligned}$

来解释一下:

  • 第一步利用参数重整化技巧以及上面的定义($\alpha$$\beta$ 的关系)进行变量替换,得到第一行的等式;
  • 第二步将 $\mathbf{x}_{t-1}$进行代换;
  • 第三步进行整理;
  • 第四步也就是最妙的地方,怎么从第三步变为第四步,这里用到了正态分布的叠加性以及再次用到了重参数化技巧,下面具体说明。

正态分布的叠加性是指相互独立的正态分布的线性组合仍然服从正态分布。比如给定两个独立的正态分布:$\mathrm{X}_{1} \sim \mathrm{N}\left(\mu_{1}, \sigma_{1}^{2}\right)$$\mathrm{X}_{2} \sim \mathrm{N}\left(\mu_{2}, \sigma_{2}^{2}\right)$,且 $a, b$ 均为实数,则有如下叠加算式:$\mathrm{aX_{1}}+\mathrm{bX_{2}} \sim \mathrm{N}\left(\mathrm{a} \mu_{1}+\mathrm{b} \mu_{2}, \mathrm{a}^{2} \sigma_{1}^{2}+\mathrm{b}^{2} \sigma_{2}^{2}\right)$

所以,由于第三步中的 $\sqrt{\alpha_{t}-\alpha_{t}\alpha_{t-1}}\mathbf{z}_{t-2}$$\sqrt{1-\alpha_{t}} \mathbf{z}_{t-1}$ 中的 $\mathbf{z}_{t-2}$$\mathbf{z}_{t-1}$ 由之前的重参数技巧知道这都是服从 $\mathcal{N}(\mathbf{0}, \mathbf{I})$ 分布的,均值为 0,方差为 1,并且相互独立,则叠加之后,方差就是前面的系数平方和,也就是 $(\sqrt{\alpha_{t}-\alpha_{t}\alpha_{t-1}})^{2} + (\sqrt{1-\alpha_{t}})^{2}=1-\alpha_{t}\alpha_{t-1}$。那么也就是说第三步的后两项和服从 $\mathcal{N}(\mathbf{0}, 1-\alpha_{t}\alpha_{t-1})$ 的分布,然后!又根据重参数技巧,随机生成一个 $\overline{\mathbf{z}}_{t-2}$ 服从正态分布,那么后两项和就可以表示为 $\sqrt{1-\alpha_{t}\alpha_{t-1}} \times \overline{\mathbf{z}}_{t-2} + \text{均值}(=0)$,就可以化简成第四行等式。

同理,第五行将 $\mathbf{x}_{t-2}$进行代换,第六行等式进行化简,接着按照正态分布叠加性和重参数化技巧进行得到第七行等式。和上面的过程一模一样。接下来就是不断地替换,不断地化简、进行正态分布的叠加性和重参数化操作,得到倒数第二行,最后由上述定义:$\bar{\alpha}_{t}=\prod_{i=1}^{T} \alpha_{i}$,得到最后一行等式。

所以由上面的推导可知,从 $\mathbf{x}_{0}$ 可以求得 $\mathbf{x}_{t}$,等式如上。并且这里又由重参数技巧反推得知:$q\left(\mathbf{x}_{t} \mid \mathbf{x}_{0}\right)=\mathcal{N}\left(\mathbf{x}_{t} ; \sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0},\left(1-\bar{\alpha}_{t}\right) \mathbf{I}\right)$,也就是说知道了 $\mathbf{x}_{0}$$\beta{t}$(上面定义过 $\alpha_{t}$$\beta{t}$ 的关系),并且再从标准正态分布中采样出一个 $z$ 就无需迭代也能采样出 $\mathbf{x}_{t}$ ,并且随着 $t$ 时刻不断增长,$\beta{t}$ 不断变大,论文中是 0.0001~0.002,则$\alpha{t}$ 不断减少,也可以看出前向扩散过程越往后,噪声影响鹅权重越来越大。而当 t 越来越大,直到 $t \to infty$$x_t$ 等同于各向同性的高斯分布,即各个方向方差都一样的多维高斯分布。

扩散过程成功推导完成,当然不迭代进行目标分布在代码上也能做到但是代码执行过程中也是要不断地进行目标迭代,相比之下推导出由 $x_{0}$ 可以得到任意时刻的 $x_{t}$ 的数学式子,不仅在推理上更好理解,在写代码的时候也更加方便。

逆扩散过程

后验条件高斯分布均值计算

相比扩散过程,逆扩散过程才是 DDPM 的核心所在。其实上面扩散过程不迭代得到目标分布的推导并不算复杂,可是接下来就是知识与耐心的考验。为什么说逆扩散过程才是核心所在?这里我借用苏剑林大佬 生成扩散模型漫谈(一):DDPM = 拆楼 + 建楼 的思想,刚才扩散过程可以看做一栋完整的高楼拆楼的过程,一步步地添加噪声其实就是去对这栋大楼进行摧毁的过程,每一步怎么摧毁(摧毁的程度)其实就是所加的噪声,形象地比喻这个过程,你就能更加理解为什么扩散的过程不需要学习,所有参数的都是已知,包括对噪声干预的把控,都是可控的。

逆扩散过程即从一个随机噪声开始,逐步还原成不带噪声的实际图片。但是拆楼容易建楼难,热力学第二定律告诉我们,宇宙的规律是熵增,是从有序到无序的过程,最终达到拉普拉斯方程描述的那种均衡;而逆过程是从噪声中恢复出清晰的信号,这是妥妥的熵减过程,这也是为什么 DDPM 需要耗费大量电力的原因,我们为了要达到熵减,不停地使用外力做功,才能达到如此完美的效果。

如果我们想要计算 $q\left(x_{t-1} \mid x_{t}\right)$,需要知道全部的数据集,

image.png

如果我们能够逐步得到逆转后的分布 $q\left(x_{t-1} \mid x_{t}\right)$ ,就可以从完全的标准高斯分布 $x_{T} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}$ 还原出原图分布 $x_0$ ,在这篇论文 中提到如果 $q\left(x_{t} \mid x_{t-1}\right)$ 满足高斯分布且 $\beta_{t}$ 足够小,$q\left(x_{t-1} \mid x_{t}\right)$ 仍然是一个高斯分布。但是很遗憾我们无法简单推断 $q\left(x_{t-1} \mid x_{t}\right)$ ,我们没有整个过程的数据集或者说没有这么一对对逆推的数据对,因为我们需要使用深度学习模型(参数为 $\theta$ ,目前主流框架是 U-Net + Attention)去预测这样的一个逆向分布 $p{\theta}$(这个过程有点类似 VAE):

$\begin{aligned} p_{\theta}\left(x_{t-1} \mid x_{t}\right) & =\mathcal{N}\left(x_{t-1} ; \mu_{\theta}\left(x_{t}, t\right), \Sigma_{\theta}\left(x_{t}, t\right)\right) \quad (1) \\ p_{\theta}\left(X_{0: T}\right) & =p\left(x_{T}\right) \prod_{t=1}^{T} p_{\theta}\left(x_{t-1} \mid x_{t}\right) \quad (2) \\ \end{aligned}$

虽然我们无法得到逆转后的分布 $q\left(x_{t-1} \mid x_{t}\right)$,但是如果知道 $x_0$,是可以通过贝叶斯公式得到 $q\left(x_{t-1} \mid x_{t}, x_{0} \right)$,而 $q\left(x_{t-1} \mid x_{t}, x_{0}\right)=\mathcal{N}\left(x_{t-1} ; \tilde{\mu}\left(x_{t}, x_{0}\right), \tilde{\beta}_{t} \mathbf{I}\right)$ 。其中很好地运用了贝叶斯定理的“逆概”作用,从而利用已知条件进行转化为:具体公式推导如下:

$\begin{aligned} q\left(x_{t-1} \mid x_{t}, x_{0}\right) & =q\left(x_{t} \mid x_{t-1}, x_{0}\right) \frac{q\left(x_{t-1} \mid x_{0}\right)}{q\left(x_{t} \mid x_{0}\right)} \\ & \propto \exp \left(-\frac{1}{2}\left(\frac{\left(x_{t}-\sqrt{\alpha_{t}} x_{t-1}\right)^{2}}{\beta_{t}}+\frac{\left(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}} x_{0}\right)^{2}}{1-\bar{a}_{t-1}}-\frac{\left(x_{t}-\sqrt{\bar{\alpha}_{t}} x_{0}\right)^{2}}{1-\bar{a}_{t}}\right)\right) \\ & =\exp \left(-\frac{1}{2}(\underbrace{\left(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) x_{t-1}^{2}}_{x_{t-1} \text { 方差 }}-\underbrace{\left(\frac{2 \sqrt{\alpha_{t}}}{\beta_{t}} x_{t}+\frac{2 \sqrt{\bar{a}_{t-1}}}{1-\bar{\alpha}_{t-1}} x_{0}\right) x_{t-1}}_{x_{t-1} \text { 均值 }}+\underbrace{C\left(x_{t}, x_{0}\right)}_{\text {与 } x_{t-1} \text { 无关 }})\right) . \end{aligned}$

解释一下:

  • 第一步利用贝叶斯公式进行“逆概”计算,把未知转化为已知条件进行求解,转换之后可以求解的原因是我们是知道:$q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)=\mathcal{N}\left(\mathbf{x}_{t} ; \sqrt{1-\beta_{t}} \mathbf{x}_{t-1}, \beta_{t} \mathbf{I}\right)$$q\left(\mathbf{x}_{t} \mid \mathbf{x}_{0}\right)=\mathcal{N}\left(\mathbf{x}_{t} ; \sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0},\left(1-\bar{\alpha}_{t}\right) \mathbf{I}\right)$ 那么自然 $q\left(x_{t-1} \mid x_{0}\right)$ 也能写出表达,即这一步可以巧妙地将逆向过程变成前向过程从而使这一条路可以走得通;
  • 第二步分别写出其对应的高斯概率密度函数,并且只取各对应的指数部分;
  • 第三步整理成 $x_{t-1}$ 的高斯分布概率密度函数形式并且利用二次函数的顶点式公式整理(高斯分布概率密度函数形式为:$f(x)=\frac{1}{\sqrt{2 \pi} \sigma} e^{-\frac{(x-\mu)^{2}}{2 \sigma^{2}}}$

稍加整理就可以得到 $q\left(x_{t-1} \mid x_{t}, x_{0}\right)=\mathcal{N}\left(x_{t-1} ; \tilde{\mu}\left(x_{t}, x_{0}\right), \tilde{\beta}_{t} \mathbf{I}\right)$ 中的均值 $\tilde{\mu}\left(x_{t}, x_{0}\right)$ 和方差 $\tilde{\beta}_{t}$

$\begin{aligned} \frac{1}{\sigma^{2}} & =\frac{1}{\tilde{\beta}_{t}}=\left(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) ; \quad \tilde{\beta}_{t}=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t}} \cdot \beta_{t} \\ \frac{2 \mu}{\sigma^{2}} & =\frac{2 \tilde{\mu}_{t}\left(x_{t}, x_{0}\right)}{\tilde{\beta}_{t}}=\left(\frac{2 \sqrt{\alpha_{t}}}{\beta_{t}} x_{t}+\frac{2 \sqrt{\bar{a}_{t}}}{1-\bar{\alpha}_{t}} x_{0}\right) ; \quad \tilde{\mu}_{t}\left(x_{t}, x_{0}\right)=\frac{\sqrt{a_{t}}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_{t}} x_{t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_{t}}{1-\bar{\alpha}_{t}} x_{0} . \end{aligned}$

又根据扩散过程求出的不迭代求出模板分布中 $x_t$$x_{0}$ 的关系表达式:$x_t = \sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \mathbf{z}_{t}$,把 x_0 分离出来得到:$x_{0}=\frac{1}{\sqrt{\bar{\alpha }_{t}}}\left(x_{t}-\sqrt{1-\bar{\alpha }_{t}} \mathbf{z}_{t}\right)$ ,即为 $x_0$ 的表达式,代入到 $q\left(x_{t-1} \mid x_{t}, x_{0} \right)$ 分布中,重新给出该分布下的均值表达式子,整理可得: $\tilde{\mu}_{t}=\frac{1}{\sqrt{\bar{\alpha }_{t}}}\left(x_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha }_{t}}} \mathbf{z}_{t}\right)$,其中 $\mathbf{z}_{t}$ 为该学习模型中所预测的噪声(用于去噪),可以看做为 $\mathbf{z}_{\theta}\left(x_{t}, t\right)$,从而得到 $\mu_{\theta}\left(x_{t}, t\right)=\frac{1}{\sqrt{\bar{\alpha }_{t}}}\left(x_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha }_{t}}} \mathbf{z}_{\theta}\left(x_{t}, t\right)\right)$

上面的推导过程概括为:给定 $x_0$ 条件下,后验条件高斯分布均值计算只与 $x_t$$\mathbf{z}_{t}$ 有关。

优化损失函数

又回到了深度学习如何训练这个问题。然后训练 DDPM 得到靠谱的 $\mu_{\theta}\left(x_{t}, t\right)$$\Sigma_{\theta}\left(x_{t}, t\right)$ 呢?回到根本还是去优化损失函数。我们通过对真实数据分布下,最大化模型预测分布的对数似然,即优化在 $x_{0} \sim q\left(x_{0}\right)$ 下的 $p_{\theta}(x_{0})$交叉熵:$\mathcal{L}=\mathbb{E}_{q\left(x_{0}\right)}\left[-\log p_{\theta}\left(x_{0}\right)\right]$,接下来这就是我们的主角。这个过程很像 VAE,就是使用变分下限(VLB)去优化负对数似然,由于 KL 散度永远非负,可得到下面的推导式子:

$\begin{aligned} -\log p_{\theta}\left(x_{0}\right) & \leq-\log p_{\theta}\left(x_{0}\right)+D_{K L}\left(q\left(x_{1: T} \mid x_{0}\right)|| p_{\theta}\left(x_{1: T} \mid x_{0}\right)\right) \\ & =-\log p_{\theta}\left(x_{0}\right)+\mathbb{E}_{q\left(x_{1: T} \mid x_{0}\right)}\left[\log \frac{q\left(x_{1: T} \mid x_{0}\right)}{p_{\theta}\left(x_{0: T}\right) / p_{\theta}\left(x_{0}\right)}\right] ; \quad \text { where } \quad p_{\theta}\left(x_{1: T} \mid x_{0}\right)=\frac{p_{\theta}\left(x_{0: T}\right)}{p_{\theta}\left(x_{0}\right)} \\ & =-\log p_{\theta}\left(x_{0}\right)+\mathbb{E}_{q\left(x_{1: T} \mid x_{0}\right)}[\log \frac{q\left(x_{1: T} \mid x_{0}\right)}{p_{\theta}\left(x_{0: T}\right)}+\underbrace{\log p_{\theta}\left(x_{0}\right)}_{\text {与 q 无关,移到外面抵消 }}] \\ & =\mathbb{E}_{q\left(x_{1: T} \mid x_{0}\right)}\left[\log \frac{q\left(x_{1: T} \mid x_{0}\right)}{p_{\theta}\left(x_{0: T}\right)}\right] . \end{aligned}$

对最后一步得到的式子左右同时取期望,使用重积分中的富比尼定理,可得到:

$\mathcal{L}_{V L B}=\mathbb{E}_{q\left(x_{0}\right)}\left(\mathbb{E}_{q\left(x_{1: T} \mid x_{0}\right)}\left[\log \frac{q\left(x_{1: T} \mid x_{0}\right)}{p_{\theta}\left(x_{0: T}\right)}\right]\right)=\mathbb{E}_{q\left(x_{0: T}\right)}\left[\log \frac{q\left(x_{1: T} \mid x_{0}\right)}{p_{\theta}\left(x_{0: T}\right)}\right] \geq - \mathbb{E}_{q\left(x_{0}\right)}\left[\log p_{\theta}\left(x_{0}\right)\right]$

那么最小化 $\mathcal{L}_{V L B}$ 就能最小化我们的目标交叉熵损失函数 $\mathcal{L}=\mathbb{E}_{q\left(x_{0}\right)}\left[-\log p_{\theta}\left(x_{0}\right)\right]$

$\mathbb{E}_{q\left(x_{0: T}\right)}\left[\log \frac{q\left(x_{1: T} \mid x_{0}\right)}{p_{\theta}\left(x_{0: T}\right)}\right] \geq-\mathbb{E}_{q\left(x_{0}\right)}\left[\log p_{\theta}\left(x_{0}\right)\right]$ ,右边是交叉熵,左边则是交叉熵上界,最小化交叉熵上界就是最小化的我们的损失函数。我们进一步写出上述式子的交叉熵的上界,并对其进行化简:

$\begin{array}{l} L_{\mathrm{VLB}}=\mathbb{E}_{q\left(\mathbf{x}_{0: T)}\right.}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0: T}\right)}\right] \quad (1)\\ =\mathbb{E}_{q}\left[\log \frac{\prod_{t=1}^{T} q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)}{p_{\theta}\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}\right] \quad (2)\\ =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=1}^{T} \log \frac{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}\right] \quad (3) \\ =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}+\log \frac{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}\right] \quad (4) \\ =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \left(\frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)} \cdot \frac{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right)}\right)+\log \frac{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}\right] \quad (5)\\ =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right)}+\log \frac{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}\right] \quad (6)\\ =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}+\log \frac{q\left(\mathbf{x}_{T} \mid \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}+\log \frac{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}\right] \quad (7) \\ =\mathbb{E}_{q}\left[\log \frac{q\left(\mathbf{x}_{T} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{T}\right)}+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}-\log p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)\right] \quad (8)\\ =\mathbb{E}_{q}[\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{T} \mid \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{T}\right)\right)}_{L_{T}}+\sum_{t=2}^{T} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)\right)}_{L_{t-1}}-\underbrace{\log p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}_{L_{0}}] \quad (9) \\ \end{array}$

有点小离谱,不要被吓到,等一下说说为什么看似变复杂的过程是在化简。

  • 第 (2) 个等式很好理解,就是将式子的意思展开为等式;

  • 第 (3) 个等式就是化简,把 $p_{\theta}$ 移出来,把剩下的进行整理;

  • 第 (4) 个等式 中括号内的第一项保持不变,右边求和把第一项单独拿出来,从第二项开始求和;

  • (5) 个等式是小巧妙点。如何从第 (4) 个等式变成第 (5) 个等式,很明显只有中括号内的求和那一项发生了改变,这里用到了这样的变化:这一项中的分母根据扩散过程是马尔可夫链其实可以从 $q\left(x_{t}\mid x_{t-1}\right)$ 等价为 $q\left(x_{t}\mid x_{t-1},x_{0}\right)$ ,而:

    $q\left(x_{t}\mid x_{t-1},x_{0}\right)=\frac{q\left(x_{t}, x_{t-1}, x_{0}\right)}{q\left(x_{t-1}, x_{0}\right)}=\underbrace{\frac{q\left(x_{t-1} \mid x_{t}, x_{0}\right) q\left(x_{t} \mid x_{0}\right) q\left(x_{0}\right)}{q\left(x_{t-1}, x_{0}\right)}}_{\text {同除} q({x_0})} =\frac{q\left(x_{t-1} \mid x_{t}, x_{0}\right) q\left(x_{t}, x_{0}\right)}{q\left(x_{t-1} \mid x_{0}\right)} $ ,然后代入就是 第 (5) 个等式;

  • 第 (6) 个等式就是将第 (5) 个等式中的 $log$ 乘变加拿出来,变成单独的两项,得到第 (6) 个等式中的四项;

  • (7) 个等式也是十分巧妙,观察可知,只有第三项变了,并且求和完竟然只变成了一项,实际上展开就能约掉:

    $\begin{aligned} & \sum_{t=2}^{T} \log \frac{q\left(x_{t} \mid x_{0}\right)}{q\left(x_{t-1} \mid x_{0}\right)} \\ = & \log \frac{q\left(x_{2} \mid x_{0}\right)}{q\left(x_{1} \mid x_{0}\right)}+\log \frac{q\left(x_{3} \mid x_{0}\right)}{q\left(x_{2} \mid x_{0}\right)}+\log \frac{q\left(x_{4} \mid x_{0}\right)}{q\left(x_{3} \mid x_{0}\right)}+\cdots+\log \frac{q\left(x_{T} \mid x_{0}\right)}{q\left(x_{T-1} \mid x_{0}\right)} \\ = & \log\left[\frac{q\left(x_{2} \mid x_{0}\right)}{q\left(x_{1} \mid x_{0}\right)} \cdot \frac{q\left(x_{3} \mid x_{0}\right)}{q\left(x_{2} \mid x_{0}\right)} \cdot \frac{q\left(x_{4}\left(x_{0}\right)\right.}{q\left(x_{3} \mid x_{0}\right)} \cdots \frac{q\left(x_{T} \mid x_{0}\right)}{q\left(x_{T-1} \mid x_{0}\right)}\right].\\ = & \log \frac{q\left(x_{T} \mid x_{0}\right)}{q\left(x_{1} \mid x_{0}\right)} \end{aligned}$

  • 第 (8) 个等式,首先就是后两项可以发现合并 $log$ 之后约掉 $q\left(x_{1}\mid x_{0}\right)$,然后把前面第一项的负号拿掉和前面合并的后两项进行合并,而 $p_{\theta}\left(x_{0}\mid x_{1}\right)$ 则变成单独的一项变成负 $log$

  • 第 (9) 个等式实际上就是转化成图上说明的各自的 KL 散度,在此等式中,$L_0$ 在 DDPM 原论文中由于选择了固定方差,故 $L_T$ 为常数,而 $L_0$ 相当于从连续空间到离散空间的解码 loss。

这里论文将 $p_{\theta}\left(x_{t-1}\mid x_{t}\right)$ 分布的方差设置成一个与 $\beta$ 相关的常数,因此可训练的参数只存在于均值中,对于两个单一变量的高斯分布 $p$$q$ 而言,他们的 KL 散度为:$K L(p, q)=\log \frac{\sigma_{2}}{\sigma_{1}}+\frac{\sigma^{2}+\left(\mu_{1}-\mu_{2}\right)^{2}}{2 \sigma_{2}^{2}}-\frac{1}{2}$

所以:

$L_{t-1}=\mathbb{E}_{q}\left[\frac{1}{2 \sigma_{t}^{2}}\left\|\tilde{\boldsymbol{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)-\boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right)\right\|^{2}\right]+C$

又:$\tilde{\mu}_{t}\left(x_{t}, x_{0}\right)=\frac{\sqrt{\alpha _{t}}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_{t}} x_{t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_{t}}{1-\bar{\alpha}_{t}} x_{0}$,代入上述式子得:

$\begin{aligned} L_{t-1}-C & =\mathbb{E}_{\mathbf{x}_{0}, \boldsymbol{\epsilon}}\left[\frac{1}{2 \sigma_{t}^{2}}\left\|\tilde{\boldsymbol{\mu}}_{t}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \boldsymbol{\epsilon}\right), \frac{1}{\sqrt{\bar{\alpha}_{t}}}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \boldsymbol{\epsilon}\right)-\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}\right)\right)-\boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \boldsymbol{\epsilon}\right), t\right)\right\|^{2}\right] \\ & =\mathbb{E}_{\mathbf{x}_{0}, \boldsymbol{\epsilon}}\left[\frac{1}{2 \sigma_{t}^{2}}\left\|\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \boldsymbol{\epsilon}\right)-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \boldsymbol{\epsilon}\right)-\boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \boldsymbol{\epsilon}\right), t\right)\right\|^{2}\right] \end{aligned}$

又:$\boldsymbol{\mu}_{\theta}\left(\mathrm{x}_{t}, t\right)=\tilde{\boldsymbol{\mu}}_{t}\left(\mathrm{x}_{t}, \frac{1}{\sqrt{\bar{\alpha}_{t}}}\left(\mathbf{x}_{t}-\sqrt{1-\bar{\alpha}_{t}} \epsilon_{\theta}\left(\mathrm{x}_{t}\right)\right)\right)=\frac{1}{\sqrt{\alpha_{t}}}\left(\mathrm{x}_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \boldsymbol{\epsilon}_{\theta}\left(\mathrm{x}_{t}, t\right)\right)$

那么:$L_{t-1}$ 可以简化成如下表达式:

$\mathbb{E}_{\mathbf{x}_{0}, \boldsymbol{\epsilon}}\left[\frac{\beta_{t}^{2}}{2 \sigma_{t}^{2} \alpha_{t}\left(1-\bar{\alpha}_{t}\right)}\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_{\theta}\left(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}, t\right)\right\|^{2}\right]$

其实就相当于让模型在给定 $x_t$$t$ 情况下能知道被加噪声是什么就可以还原了。这个过程在让反向扩散逼近后验分布。

最后呢,作者发现干脆把系数丢掉,训练更加稳定质量更好,于是就有了 $L_{\text {simple }}(\theta)$

$L_{\text {simple }}(\theta):=\mathbb{E}_{t, \mathbf{x}_{0}, \boldsymbol{\epsilon}}\left[\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_{\theta}\left(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}, t\right)\right\|^{2}\right]$

从最后的式子也可以看出,DDPM并没有将模型预测的方差 $\Sigma_{\theta}\left(x_{t}, t\right)$ 考虑到训练和推断中,而是通过将未经训练的 $\beta_{t}$ 或者 $\tilde{\beta}_{t}$ 代替,因为他们发现 $\Sigma_{\theta}\left(x_{t}, t\right)$ 可能导致训练的不稳定。

总结

DDPM 原文看不懂的原因是作者只给出了这么一个图,实际上里面暗含了我们上述的推导,但是我们一旦把上述过程搞清楚,再去看这个图实际上也能让我们思路更加清晰一点:

image.png

DDPM 代码实现小 Demo

(待补充)

DDPM 改进

(待补充)

参考