VAE变分自编码器详解
1️⃣ 为什么提出VAE?
自编码器的局限性
自编码器的结构如下图所示:
基于神经网络强大的拟合能力,使得编码(code)的维度比原始图像小很多,而且解码后图像和原图像接近。我们基于自编码器实现了重构功能。
然而基于上述自编码器没有办法来产生任何新内容,原因分析如下:
如上图所示,假设有两张训练图片,一张是全月图,一张是半月图,经过训练我们的自编码器模型已经能无损地还原这两张图片。接下来,我们在code空间上,两张图片的编码点中间处取一点,然后将这一点交给解码器,我们希望新的生成图片是一张清晰的图片(类似3/4全月的样子)。但是,实际的结果是,生成图片是模糊且无法辨认的乱码图。一个比较合理的解释是,因为编码和解码的过程使用了深度神经网络,这是一个非线性的变换过程,所以在code空间上点与点之间的迁移是非常没有规律的。
如何解决这个问题呢?我们可以引入噪声,使得图片的编码区域得到扩大,从而掩盖掉失真的空白编码点。
如上图所示,现在在给两张图片编码的时候加上一点噪音,使得每张图片的编码点出现在绿色箭头所示范围内,于是在训练模型的时候,绿色箭头范围内的点都有可能被采样到,这样解码器在训练时会把绿色范围内的点都尽可能还原成和原图相似的图片。然后我们可以关注之前那个失真点,现在它处于全月图和半月图编码的交界上,于是解码器希望它既要尽量相似于全月图,又要尽量相似于半月图,于是它的还原结果就是两种图的折中(3/4全月图)。
此我们发现,给编码器增添一些噪音,可以有效覆盖失真区域。不过这还并不充分,因为在上图的距离训练区域很远的黄色点处,它依然不会被覆盖到,仍是个失真点。为了解决这个问题,我们可以试图把噪音无限拉长,使得对于每一个样本,它的编码会覆盖整个编码空间,不过我们得保证,在原编码附近编码的概率最高,离原编码点越远,编码概率越低。在这种情况下,图像的编码就由原先离散的编码点变成了一条连续的编码分布曲线,如下图所示。
也就是说,自编码器将输入编码为隐空间中的单个点,而VAE是将其编码为隐空间中的一个分布。
2️⃣ VAE具体原理
VAE的结构
上面介绍到,自编码器将输入编码为隐空间中的单个点,而VAE是将其编码为隐空间中的概率分布,具体流程为:
- 首先,将输入编码为在隐空间上的分布;
- 第二,从该分布中采样隐空间中的一个点(重参数化);
- 第三,对采样点进行解码并计算出重建误差;
- 最后,重建误差通过网络反向传播。
总而言之,编码器学习输入数据的分布;解码器从分布中采样并还原成最终的输出。
在实践中,隐空间的分布强制为高斯分布
,可以训练编码器来输出高斯分布的均值和协方差矩阵。
因此,在训练VAE时最小化的损失函数由两部分构成:
- “重构项”(在最后一层):“重构项”倾向于使编码-解码方案尽可能地具有高性能
- “正则化项”(在隐层):通过使编码器输出的分布接近标准正态分布,来规范隐空间。该正则化项为编码器输出的分布与标准高斯之间的Kulback-Leibler散度
注:两个高斯分布之间的Kullback-Leibler散度具有封闭形式,可以直接用两个分布的均值和协方差矩阵表示。
KL散度又称为相对熵,其定义为 K L ( p ( x ) , q ( x ) ) = ∑ p ( x ) l o g p ( x ) q ( x ) KL(p(x),q(x))=\sum p(x)log\frac{p(x)}{q(x)} KL(p(x),q(x))=∑p(x)logq(x)p(x)。这个概念很重要,不只是VAE,很多地方都会用到。
关于“正则化项”的直观解释
为了使生成过程成为可能,我们期望隐空间具有规则性,这可以通过两个主要属性表示:
- 连续性(continuity,隐空间中的两个相邻点解码后不应呈现两个完全不同的内容)
- 完整性(completeness,针对给定的分布,从隐空间采样的点在解码后应提供“有意义”的内容)。
上图中左边所展示的是不规则隐空间,黑色圆圈中的两个点临近,但是解码后不相似;紫色的点,解码后没有意义。
然而,VAE如果不定义正则化项,则模型会像自编码器一样,仅最小化重构误差。具体而言,编码器返回的分布方差极小(往往是点分布)或或者返回具有巨大均值差异的分布(数据在隐空间中彼此相距很远)。在这两种情况下,返回分布的限制都没有取得效果,并且不满足连续性和/或完整性。
因此,为了避免上述影响,必须同时对协方差矩阵和分布均值进行正则化。在实际中,通过强制分布接近标准正态分布
来完成此正则化。这样,我们要求协方差矩阵接近于单位阵,防止出现单点分布,并且均值接近于0,防止编码分布彼此相距太远。
使用正则化项,可以满足预期的连续性和完整性条件。然而,使用正则化项,会以训练数据上更高的重建误差为代价,因此需要调整重建误差和KL散度之间的权重
通过正则化,会有以下现象,从分布的“重叠区域”采样的点,解码出来会是“综合形状”。例如,三角编码到橙色隐空间的中心点,如果采样这个点,重建的就是三角;圆圈编码到蓝色隐空间的中心点,如果采样这个点,重建的就是圆圈;如果此时采样蓝色隐空间和橙色隐空间中间位置的点,那重建的应该是圆滑的三角形。
VAE的数学解释
我们有一堆数据 X X X,比如手写数字图片,想找到它们的概率分布 p ( X ) p(X) p(X)。如果知道 p ( X ) p(X) p(X),我们就能从中采样,生成新的类似图片。 p ( X ) p(X) p(X) 越大,模型越能“理解”训练数据(手写图片)的规律。然而,直接求 p ( X ) p(X) p(X)太难,因为图片高维、复杂。所以现在的目标是,如何求 p ( X ) p(X) p(X)并让它最大化。
为了解决上述问题,可以假设数据 X X X 由一些隐藏的 z z z(潜在变量)生成。 z z z服从一个简单的分布 p ( z ) p(z) p(z),比如标准正态分布 N ( 0 , 1 ) N(0, 1) N(0,1)。用一个“生成规则”,即 p ( X ∣ z ) p(X|z) p(X∣z) 从 z z z 产生 X X X。假设数据 X X X 由潜在变量 z z z 生成:
p ( X ) = ∫ p ( X ∣ z ) p ( z ) d z p(X) = \int p(X|z) p(z) \, dz p(X)=∫p(X∣z)p(z)dz
想求 p ( X ) p(X) p(X),就得求这个积分,但是这积分高维,也求不出来。
那还有其他方法吗?根据贝叶斯公式,我们可以得到:
p ( z ∣ X ) = p ( X ∣ z ) p ( z ) p ( X ) p(z|X) = \frac{p(X|z) p(z)}{p(X)} p(z∣X)=p(X)p(X∣z)p(z)
- 其中 p ( z ) p(z) p(z)是先验分布,通常是标准正态分布 p ( z ) ∼ N ( 0 , 1 ) p(z) \sim N(0, 1) p(z)∼N(0,1);
- p ( z ∣ X ) p(z|X) p(z∣X)是后验分布,给定数 X X X,推断潜在变量 z z z 的概率
- p ( X ∣ z ) p(X|z) p(X∣z):条件概率,给定 z z z,生成数据 X X X 的概率(似然)
我们想要通过 p ( z ∣ X ) p(z|X) p(z∣X)求 p ( X ) p(X) p(X),而 p ( X ) p(X) p(X)正是我们想求的,陷入循环。
既然如此,尝试用变分推断间接优化:
构造一个简单的分布 q ( z ∣ X ) q(z|X) q(z∣X),用神经网络(编码器)学习,输出均值和方差,近似 p ( z ∣ X ) p(z|X) p(z∣X)。如何衡量 q ( z ∣ X ) q(z|X) q(z∣X)和 p ( z ∣ X ) p(z|X) p(z∣X)和差异呢?那就是KL散度:
D K L ( q ( z ∣ X ) ∥ p ( z ∣ X ) ) = E q ( z ∣ X ) [ log q ( z ∣ X ) p ( z ∣ X ) ] D_{KL}(q(z|X) \| p(z|X)) = \mathbb{E}_{q(z|X)}[\log \frac{q(z|X)}{p(z|X)}] DKL(q(z∣X)∥p(z∣X))=Eq(z∣X)[logp(z∣X)q(z∣X)]
进一步得到:
D K L ( q ( z ∣ X ) ∥ p ( z ∣ X ) ) = E q ( z ∣ X ) [ log q ( z ∣ X ) ] − E q ( z ∣ X ) [ log p ( z ∣ X ) ] D_{KL}(q(z|X) \| p(z|X)) = \mathbb{E}_{q(z|X)}[\log q(z|X)] - \mathbb{E}_{q(z|X)}[\log p(z|X)] DKL(q(z∣X)∥p(z∣X))=Eq(z∣X)[logq(z∣X)]−Eq(z∣X)[logp(z∣X)]
代入贝叶斯公式:
用 p ( z ∣ X ) = p ( X ∣ z ) p ( z ) p ( X ) p(z|X) = \frac{p(X|z) p(z)}{p(X)} p(z∣X)=p(X)p(X∣z)p(z) 替换:
D K L ( q ( z ∣ X ) ∥ p ( z ∣ X ) ) = E q ( z ∣ X ) [ log q ( z ∣ X ) ] − E q ( z ∣ X ) [ log p ( X ∣ z ) p ( z ) p ( X ) ] D_{KL}(q(z|X) \| p(z|X)) = \mathbb{E}_{q(z|X)}[\log q(z|X)] - \mathbb{E}_{q(z|X)}[\log \frac{p(X|z) p(z)}{p(X)}] DKL(q(z∣X)∥p(z∣X))=Eq(z∣X)[logq(z∣X)]−Eq(z∣X)[logp(X)p(X∣z)p(z)]
展开对数:
因为 log p ( X ∣ z ) p ( z ) p ( X ) = log p ( X ∣ z ) + log p ( z ) − log p ( X ) \log \frac{p(X|z) p(z)}{p(X)} = \log p(X|z) + \log p(z) - \log p(X) logp(X)p(X∣z)p(z)=logp(X∣z)+logp(z)−logp(X),代入得到:
D K L ( q ( z ∣ X ) ∥ p ( z ∣ X ) ) = E q ( z ∣ X ) [ log q ( z ∣ X ) ] − E q ( z ∣ X ) [ log p ( X ∣ z ) ] − E q ( z ∣ X ) [ log p ( z ) ] + E q ( z ∣ X ) [ log p ( X ) ] D_{KL}(q(z|X) \| p(z|X)) = \mathbb{E}_{q(z|X)}[\log q(z|X)] - \mathbb{E}_{q(z|X)}[\log p(X|z)] - \mathbb{E}_{q(z|X)}[\log p(z)] + \mathbb{E}_{q(z|X)}[\log p(X)] DKL(q(z∣X)∥p(z∣X))=Eq(z∣X)[logq(z∣X)]−Eq(z∣X)[logp(X∣z)]−Eq(z∣X)[logp(z)]+Eq(z∣X)[logp(X)]
简化,注意到 log p ( X ) \log p(X) logp(X) 不依赖 z z z,所以期望可以提出来:
E q ( z ∣ X ) [ log p ( X ) ] = log p ( X ) \mathbb{E}_{q(z|X)}[\log p(X)] = \log p(X) Eq(z∣X)[logp(X)]=logp(X)
于是公式变成:
D K L ( q ( z ∣ X ) ∥ p ( z ∣ X ) ) = E q ( z ∣ X ) [ log q ( z ∣ X ) ] − E q ( z ∣ X ) [ log p ( X ∣ z ) ] − E q ( z ∣ X ) [ log p ( z ) ] + log p ( X ) D_{KL}(q(z|X) \| p(z|X)) = \mathbb{E}_{q(z|X)}[\log q(z|X)] - \mathbb{E}_{q(z|X)}[\log p(X|z)] - \mathbb{E}_{q(z|X)}[\log p(z)] + \log p(X) DKL(q(z∣X)∥p(z∣X))=Eq(z∣X)[logq(z∣X)]−Eq(z∣X)[logp(X∣z)]−Eq(z∣X)[logp(z)]+logp(X)
移项,把 log p ( X ) \log p(X) logp(X) 移到左边:
log p ( X ) = E q ( z ∣ X ) [ log p ( X ∣ z ) ] − E q ( z ∣ X ) [ log q ( z ∣ X ) ] + E q ( z ∣ X ) [ log p ( z ) ] + D K L ( q ( z ∣ X ) ∥ p ( z ∣ X ) ) \log p(X) = \mathbb{E}_{q(z|X)}[\log p(X|z)] - \mathbb{E}_{q(z|X)}[\log q(z|X)] + \mathbb{E}_{q(z|X)}[\log p(z)] + D_{KL}(q(z|X) \| p(z|X)) logp(X)=Eq(z∣X)[logp(X∣z)]−Eq(z∣X)[logq(z∣X)]+Eq(z∣X)[logp(z)]+DKL(q(z∣X)∥p(z∣X))
注意到 D K L ( q ( z ∣ X ) ∥ p ( z ) ) = E q ( z ∣ X ) [ log q ( z ∣ X ) p ( z ) ] D_{KL}(q(z|X) \| p(z)) = \mathbb{E}_{q(z|X)}[\log \frac{q(z|X)}{p(z)}] DKL(q(z∣X)∥p(z))=Eq(z∣X)[logp(z)q(z∣X)],所以:
E q ( z ∣ X ) [ log q ( z ∣ X ) ] − E q ( z ∣ X ) [ log p ( z ) ] = − D K L ( q ( z ∣ X ) ∥ p ( z ) ) \mathbb{E}_{q(z|X)}[\log q(z|X)] - \mathbb{E}_{q(z|X)}[\log p(z)] = -D_{KL}(q(z|X) \| p(z)) Eq(z∣X)[logq(z∣X)]−Eq(z∣X)[logp(z)]=−DKL(q(z∣X)∥p(z))
代入后,公式整理为:
log p ( X ) = E q ( z ∣ X ) [ log p ( X ∣ z ) ] − D K L ( q ( z ∣ X ) ∥ p ( z ) ) + D K L ( q ( z ∣ X ) ∥ p ( z ∣ X ) ) \log p(X) = \mathbb{E}_{q(z|X)}[\log p(X|z)] - D_{KL}(q(z|X) \| p(z)) + D_{KL}(q(z|X) \| p(z|X)) logp(X)=Eq(z∣X)[logp(X∣z)]−DKL(q(z∣X)∥p(z))+DKL(q(z∣X)∥p(z∣X))
最后一项 D K L ( q ( z ∣ X ) ∥ p ( z ∣ X ) ) D_{KL}(q(z|X) \| p(z|X)) DKL(q(z∣X)∥p(z∣X)) 算不出,因为 p ( z ∣ X ) p(z|X) p(z∣X) 未知,但注意到KL散度非负,所以有:
log p ( X ) ≥ E q ( z ∣ X ) [ log p ( X ∣ z ) ] − D K L ( q ( z ∣ X ) ∥ p ( z ) ) \log p(X) \geq \mathbb{E}_{q(z|X)}[\log p(X|z)] - D_{KL}(q(z|X) \| p(z)) logp(X)≥Eq(z∣X)[logp(X∣z)]−DKL(q(z∣X)∥p(z))
右边的式子称为:证据下界(ELBO)
想要最大化 log p ( X ) \log p(X) logp(X),就是最大化ELBO:
- 重构项 E q ( z ∣ X ) [ log p ( X ∣ z ) ] \mathbb{E}_{q(z|X)}[\log p(X|z)] Eq(z∣X)[logp(X∣z)]:从 q ( z ∣ X ) q(z|X) q(z∣X) 采样 z z z,用 p ( X ∣ z ) p(X|z) p(X∣z) 重建 X X X,希望重建结果接近原始数据。
- KL散度项 − D K L ( q ( z ∣ X ) ∥ p ( z ) ) -D_{KL}(q(z|X) \| p(z)) −DKL(q(z∣X)∥p(z)):让 q ( z ∣ X ) q(z|X) q(z∣X) 接近已知的先验 p ( z ) ∼ N ( 0 , 1 ) p(z) \sim N(0, 1) p(z)∼N(0,1)
因此损失函数就是最小化负的ELBO:
loss = − E q ( z ∣ X ) [ log p ( X ∣ z ) ] + D K L ( q ( z ∣ X ) ∥ p ( z ) ) \text{loss} = -\mathbb{E}_{q(z|X)}[\log p(X|z)] + D_{KL}(q(z|X) \| p(z)) loss=−Eq(z∣X)[logp(X∣z)]+DKL(q(z∣X)∥p(z))
对应到实际代码中:
- − E q ( z ∣ X ) [ log p ( X ∣ z ) ] ≈ MSE ( X , X ^ ) -\mathbb{E}_{q(z|X)}[\log p(X|z)] \approx \text{MSE}(X, \hat{X}) −Eq(z∣X)[logp(X∣z)]≈MSE(X,X^)
- D K L ( q ( z ∣ X ) ∥ p ( z ) ) = − 0.5 ∗ ∑ ( 1 + log ( σ 2 ) − μ 2 − σ 2 ) D_{KL}(q(z|X) \| p(z))= -0.5 * \sum \left(1 + \log(\sigma^2) - \mu^2 - \sigma^2\right) DKL(q(z∣X)∥p(z))=−0.5∗∑(1+log(σ2)−μ2−σ2)
实际应用时,该重建函数即可。KL是固定的。
重参数化
在VAE中,我们希望通过编码器学习一个后验分布 q ( z ∣ x ) q(z|x) q(z∣x),表示给定输入数据 x x x,潜在变量 z z z 的概率分布。通常假设 q ( z ∣ x ) ∼ N ( μ , σ 2 ) q(z|x) \sim N(\mu, \sigma^2) q(z∣x)∼N(μ,σ2),即一个均值为 μ \mu μ,方差为 σ 2 \sigma^2 σ2 的正态分布
。
训练神经网络需要通过梯度下降优化参数,但直接从 q ( z ∣ x ) ∼ N ( μ , σ 2 ) q(z|x) \sim N(\mu, \sigma^2) q(z∣x)∼N(μ,σ2) 中采样 z z z 是一个随机过程,梯度无法直接通过采样操作传递到 μ \mu μ 和 σ \sigma σ。这是因为随机采样的操作(如从正态分布生成随机数)不可微,无法计算梯度。
为什么采样不可微:一个函数是“可微的”意味着它有定义良好的导数,导数要求输出对输入的变化是平滑且可预测的。但是采样的随机性导致输出与输入的关系不连续、不确定,因此没办法求导数。
我们需要一种方法,让 z z z 的生成过程既能保留随机性,又能让梯度通过 μ \mu μ 和 σ \sigma σ 传播,以便优化编码器的参数。
这时候,重参数化登场了:
对于一个服从正态分布
的随机变量 z ∼ N ( μ , σ 2 ) z \sim N(\mu, \sigma^2) z∼N(μ,σ2),我们可以通过标准化将其转换为标准正态分布:
z − μ σ ∼ N ( 0 , 1 ) \frac{z - \mu}{\sigma} \sim N(0, 1) σz−μ∼N(0,1)
- 这里, z − μ z - \mu z−μ 消除了均值的影响,平移分布使均值为 0。
- 除以标准差 σ \sigma σ 则标准化了方差,使其变为 1。
- 结果是 z − μ σ \frac{z - \mu}{\sigma} σz−μ 服从
标准正态分布
N ( 0 , 1 ) N(0, 1) N(0,1),均值为 0,方差为 1。
基于标准化公式,我们可以定义一个噪声项 ϵ \epsilon ϵ:
ϵ = z − μ σ \epsilon = \frac{z - \mu}{\sigma} ϵ=σz−μ
由于 z − μ σ ∼ N ( 0 , 1 ) \frac{z - \mu}{\sigma} \sim N(0, 1) σz−μ∼N(0,1),我们有 ϵ ∼ N ( 0 , 1 ) \epsilon \sim N(0, 1) ϵ∼N(0,1),即 ϵ \epsilon ϵ 是从标准正态分布
中采样的随机变量。
将公式重新整理, z z z可以表示为:
z = μ + ϵ ⋅ σ z = \mu + \epsilon \cdot \sigma z=μ+ϵ⋅σ
这里的 μ \mu μ 和 σ \sigma σ 是编码器的输出(均值和标准差),由神经网络根据输入 x x x 计算得到。
ϵ \epsilon ϵ 是从标准正态分布 N ( 0 , 1 ) N(0, 1) N(0,1) 中独立采样的噪声,引入随机性。
综上所述,总结一下:
原始方式:直接从 N ( μ , σ 2 ) N(\mu, \sigma^2) N(μ,σ2) 采样 z z z,采样过程不可微,梯度无法传递到 μ \mu μ 和 σ \sigma σ。
重参数化方式:
- 先从标准正态分布 N ( 0 , 1 ) N(0, 1) N(0,1) 采样一个独立的噪声 ϵ \epsilon ϵ,这是一个固定的随机过程,与网络参数无关。
- 然后通过确定性变换 z = μ + ϵ ⋅ σ z = \mu + \epsilon \cdot \sigma z=μ+ϵ⋅σ 计算 z z z。
- 现在, z z z 是 μ \mu μ 和 σ \sigma σ 的函数,而 μ \mu μ 和 σ \sigma σ 是神经网络的输出,梯度可以通过这个确定性变换传递到网络参数。
3️⃣ 代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.autograd import Variable# 定义变分自编码器(VAE)模型
class VAE(nn.Module):def __init__(self, input_dim, hidden_dim, latent_dim):super(VAE, self).__init__()# 编码器用于近似后验分布 q(z|x)self.encoder = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, latent_dim * 2) # 输出均值和对数方差log(σ^2))# 解码器建模条件分布p(x|z),即从潜在变量z重建输入数据xself.decoder = nn.Sequential(nn.Linear(latent_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, input_dim),nn.Sigmoid())def reparameterize(self, mu, logvar):# mu:编码器输出的均值# logvar[即log(σ^2)]:编码器输出的对数方差# 方差σ^2 = e^{log(σ^2)}# 因此标准差 σ=std=e^{0.5*log(σ^2)}std = torch.exp(0.5 * logvar) # 生成与标准差std形状相同的随机噪声,服从标准正态分布N(0, 1)eps = torch.randn_like(std)# 通过重参数化技巧计算潜在变量zz = mu + eps * std return zdef forward(self, x):# 编码encoded = self.encoder(x)# 将输出分割为均值和方差mu, logvar = torch.chunk(encoded, 2, dim=1) # 通过重参数化技巧从q(z|x)采样潜在变量z z = self.reparameterize(mu, logvar) # 解码decoded = self.decoder(z)return decoded, mu, logvar# 定义训练函数
def train_vae(model, train_loader, num_epochs, learning_rate):#criterion = nn.MSELoss() # 二元交叉熵损失函数,适用于自然图像criterion = nn.BCELoss() # 二元交叉熵损失函数, 适用于二值图像optimizer = optim.Adam(model.parameters(), lr=learning_rate) # Adam优化器model.train() # 设置模型为训练模式for epoch in range(num_epochs):total_loss = 0.0for data in train_loader:images, _ = dataimages = images.view(images.size(0), -1) # 展平输入图像optimizer.zero_grad()# 前向传播outputs, mu, logvar = model(images)# 计算重构损失reconstruction_loss = criterion(outputs, images)# KL散度,衡量q(z|x)与先验p(z) ~ N(0, 1)的差异# 计算公式为:具体见下面的分析kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())# 计算总损失loss = reconstruction_loss + kl_divergence# 反向传播和优化loss.backward()optimizer.step()total_loss += loss.item()# 输出当前训练轮次的损失print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, total_loss / len(train_loader)))print('Training finished.')# 示例用法
if __name__ == '__main__':# 设置超参数input_dim = 784 # 输入维度(MNIST图像的大小为28x28,展平后为784)hidden_dim = 256 # 隐层维度latent_dim = 64 # 潜在空间维度num_epochs = 10 # 训练轮次learning_rate = 0.001 # 学习率# 加载MNIST数据集train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)# 创建VAE模型model = VAE(input_dim, hidden_dim, latent_dim)# 训练VAE模型train_vae(model, train_loader, num_epochs, learning_rate)
这里面有几个疑惑:
-
为什么encoder输出的是对数方差?
答:标准差可以表示为std或 σ \sigma σ,而方差是 σ 2 \sigma^2 σ2,它一定是非负的。如果我们将神经网络的输出直接定义为 σ 2 \sigma^2 σ2,就会有问题,因为神经网络可能会输出负值,在数学上没有意义。
因此,我们假设神经网络输出的是方差的对数,即 log ( σ 2 ) \log(\sigma^2) log(σ2),它是负数也没关系,因为,我们想得到 σ 2 \sigma^2 σ2的话,需要对它取指数,即 e log ( σ 2 ) e^{\log(\sigma^2)} elog(σ2),这样即使对数方差是负的,取指数后依然可以得到正数的 σ 2 \sigma^2 σ2。 -
损失函数就是最小化负的ELBO,表达式为:
loss = − E q ( z ∣ X ) [ log p ( X ∣ z ) ] + D K L ( q ( z ∣ X ) ∥ p ( z ) ) \text{loss} = -\mathbb{E}_{q(z|X)}[\log p(X|z)] + D_{KL}(q(z|X) \| p(z)) loss=−Eq(z∣X)[logp(X∣z)]+DKL(q(z∣X)∥p(z))-
− E q ( z ∣ X ) [ log p ( X ∣ z ) ] ≈ MSE ( X , X ^ ) -\mathbb{E}_{q(z|X)}[\log p(X|z)] \approx \text{MSE}(X, \hat{X}) −Eq(z∣X)[logp(X∣z)]≈MSE(X,X^)
-
D K L ( q ( z ∣ X ) ∥ p ( z ) ) = − 0.5 ∗ ∑ ( 1 + log ( σ 2 ) − μ 2 − σ 2 ) D_{KL}(q(z|X) \| p(z))= -0.5 * \sum \left(1 + \log(\sigma^2) - \mu^2 - \sigma^2\right) DKL(q(z∣X)∥p(z))=−0.5∗∑(1+log(σ2)−μ2−σ2)
kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())loss=MSE+kl_divergence
-
4️⃣ 知识点
高斯分布
高斯分布,也称为正态分布,是一种连续概率分布,广泛用于统计和概率论中。其概率密度函数呈钟形曲线,对称于均值。以下是其关键特点:
数学表达式:概率密度函数为
f ( x ) = 1 2 π σ 2 e − ( x − μ ) 2 2 σ 2 f(x) = \frac{1}{\sqrt{2\pi\sigma^2}} e^{-\frac{(x-\mu)^2}{2\sigma^2}} f(x)=2πσ21e−2σ2(x−μ)2
其中, μ \mu μ 是均值(决定分布的中心), σ \sigma σ 是标准差(决定分布的宽度), σ 2 \sigma^2 σ2 是方差。
特性:
对称性:曲线以均值为中心,左右对称。
钟形:数据集中在均值附近,两侧逐渐减少。
68-95-99.7 法则:在 μ ± σ \mu \pm \sigma μ±σ 范围内包含约 68% 的数据, μ ± 2 σ \mu \pm 2\sigma μ±2σ 包含约 95%, μ ± 3 σ \mu \pm 3\sigma μ±3σ 包含约 99.7%
先验概率和后验概率
先验概率:
- 定义:在获取新的证据或数据之前,基于已有知识或假设对某一事件发生概率的估计。
- 用 P ( A ) P(A) P(A) 表示,其中 A A A是某个事件
后验概率:
- 定义:在给定新的证据或数据后,通过贝叶斯定理更新后的条件概率。
- 根据贝叶斯定理, P ( A ∣ B ) = P ( B ∣ A ) ⋅ P ( A ) P ( B ) P(A|B) = \frac{P(B|A) \cdot P(A)}{P(B)} P(A∣B)=P(B)P(B∣A)⋅P(A)
其中:
P ( A ∣ B ) P(A|B) P(A∣B):后验概率,即在给定 B B B 后, A A A 发生的概率。
P ( B ∣ A ) P(B|A) P(B∣A):似然概率,即在事件 A A A 发生时证据 B B B 出现的概率。
P ( A ) P(A) P(A):先验概率。
P ( B ) P(B) P(B):证据 B B B的总概率, P ( B ) = P ( B ∣ A ) ⋅ P ( A ) + P ( B ∣ 非 A ) ⋅ P ( 非 A ) P(B) = P(B|A) \cdot P(A) + P(B|\text{非}A) \cdot P(\text{非}A) P(B)=P(B∣A)⋅P(A)+P(B∣非A)⋅P(非A)
A A A:感染新冠病毒
B B B:新冠检测为阳性
所以 P ( A ∣ B ) P(A|B) P(A∣B),后验概率表示,检测阳性后实际感染新冠的概率。
5️⃣ 总结
自编码器只能进行重建,无法生成新样本。VAE通过引入正则化项,使得隐空间的分布从某种先验分布(通常是标准正态分布),具有连续性和可解释性,使得模型不仅能重建数据,还能从隐空间分布采样生成新数据,适用于生成任务
6️⃣ 参考
半小时理解变分自编码器
生成模型——变分自编码器