当前位置: 首页 > news >正文

生成式人工智能实战 | 变分自编码器(Variational Auto-Encoder, VAE)

生成式人工智能实战 | 变分自编码器

    • 0. 前言
    • 1. 潜空间运算
    • 2. 变分自编码器
      • 2.1 VAE 工作原理
      • 2.2 VAE 构建策略
      • 2.3 KL 散度
      • 2.4 重参数化技巧
    • 3. 实现 VAE
      • 3.1 数据加载
      • 3.2 模型构建
      • 3.3 模型训练

0. 前言

虽然自编码器 (AutoEncoder, AE) 在重建输入数据方面表现良好,但通常在生成训练集中不存在的新样本时表现不佳。更重要的是,自编码器在输入插值方面同样表现不佳,无法生成两个输入数据点之间的中间表示。这就引出了变分自编码器 (Variational Auto-Encoder, VAE),变分自编码器是一种生成模型,结合了深度学习和概率图模型的优点,通过学习数据的潜在概率分布来生成新的数据样本。本节将从零开始构建和训练一个 VAE,使用 cifar-10 数据集训练 VAE

1. 潜空间运算

使用变分自编码器 (Variational Auto-Encoder, VAE) 可以进行向量运算和输入插值。操作不同输入的编码表示(潜向量),以在解码时实现特定的结果(例如,图像中是否具有某些特征)。潜向量控制解码图像中的不同特征,如性别、图像中是否有眼镜等。例如,可以首先获得戴眼镜的男性的潜向量 (z1)、戴眼镜的女性的潜向量 (z2) 和不戴眼镜的女性的潜向量 (z3)。然后,计算一个新的潜向量 z4 = z1 – z2 + z3。由于 z1z2 解码后都会出现眼镜,z1 – z2 会在结果图像中去除眼镜特征。类似地,由于 z2z3 都会解码为女性面孔,z3 – z2 会去除结果图像中的女性特征。因此,如果使用训练好的 VAE 解码 z4 将得到一张没有不戴眼镜的男性图像。

2. 变分自编码器

虽然自编码器 (AutoEncoder, AE) 擅长重建原始图像,但它们在生成训练集中没有出现的新图像方面表现不佳。此外,自编码器通常无法将相似的输入映射到潜空间中的相邻点。因此,AE 的潜空间既不连续,也不容易解释。例如,无法通过插值两个输入数据点来生成有意义的中间表示。基于这些原因,我们将学习自编码器的改进模型,变分自编码器 (Variational Auto-Encoder, VAE)。

2.1 VAE 工作原理

VAE 使用深度学习构建概率模型,将输入数据映射到一个低维度的潜空间中,并通过解码器将潜空间中的分布转换回数据空间中,以生成与原始数据相似的数据。与传统的自编码器相比,VAE 更加稳定,生成样本的质量更高。
VAE 的核心思想是利用概率模型来描述高维的输入数据,将输入数据采样于一个低维度的潜变量分布中,并通过解码器生成与原始数据相似的输出。具体来说,VAE 同样是由编码器和解码器组成:

  • 编码器将数据 x x x 映射到一个潜在空间 z z z 中,该空间定义在低维正态分布中,即 z ∼ N ( 0 , I ) z∼N(0,I) zN(0,I),编码器由两个部分组成:一是将数据映射到均值和方差,即 z ∼ N ( μ , σ 2 ) z∼N(μ,σ^2) zN(μ,σ2);二是通过重参数化技巧,将均值和方差的采样过程分离出来,并引入随机变量 ϵ ∼ N ( 0 , I ) ϵ∼N(0,I) ϵN(0,I),使得 z = μ + ϵ σ z=μ+ϵσ z=μ+ϵσ
  • 解码器将潜在变量 z z z 映射回数据空间中,生成与原始数据 x x x 相似的数据 x ′ x′ x,为了使生成的数据 x ′ x′ x 能够与原始数据 x x x 较高的相似度,VAE 在损失函数中使用重构误差和正则化项,重构误差表示生成数据与原始数据之间的差异,正则化项用于约束潜在变量的分布,使其满足高斯正态分布,使得 VAE 从潜空间中生成的样本质量更高

VAE 具有广泛的应用场景,如图像生成、语音、自然语言处理等领域,它能够通过有限的数据样本学习到输入数据背后的潜在规律,生成与原始数据类似的新数据,具有很强的潜数据的可解释性。

2.2 VAE 构建策略

VAE 中,基于预定义分布获得的随机向量生成逼真图像,而在传统自编码器中并未指定在网络中生成图像的数据分布。可以通过以下策略,实现 VAE:

  1. 编码器的输出包括两个向量:
    • 输入图像平均值
    • 输入图像标准差
  2. 根据以上两个向量,通过在均值和标准差之和中引入随机变量 ( ϵ ∼ N ( 0 , I ) ϵ∼N(0,I) ϵN(0,I)) 获取随机向量 ( z = μ + ϵ σ z=μ+ϵσ z=μ+ϵσ)
  3. 将上一步得到的随机向量作为输入传递给解码器以重构图像
  4. 损失函数是均方误差和 KL 散度损失的组合:
    • KL 散度损失衡量由均值向量 μ \mu μ 和标准差向量 σ \sigma σ 构建的分布与 N ( 0 , I ) N(0,I) N(0,I) 分布的偏差
    • 均方损失用于优化重建(解码)图像

通过训练网络,指定输入数据满足由均值向量 μ \mu μ 和标准差向量 σ \sigma σ 构建的 N ( 0 , 1 ) N(0,1) N(0,1) 分布,当我们生成均值为 0 且标准差为 1 的随机噪声时,解码器将能够生成逼真的图像。
需要注意的是,如果只最小化 KL 散度,编码器将预测均值向量为 0,标准差为 1。因此,需要同时最小化 KL 散度损失和均方损失。在下一节中,让我们介绍 KL 散度,以便将其纳入模型的损失值计算中。

2.3 KL 散度

KL 散度(也称相对熵)可以用于衡量两个概率分布之间的差异:

K L ( P ∣ ∣ Q ) = ∑ x ∈ X P ( x ) l n ( P ( i ) Q ( i ) ) KL(P||Q) = \sum_{x∈X} P(x) ln(\frac {P(i)}{Q(i)}) KL(P∣∣Q)=xXP(x)ln(Q(i)P(i))

其中, P P P Q Q Q 为两个概率分布,KL 散度的值越小,两个分布的相似性就越高,当且仅当 P P P Q Q Q 两个概率分布完全相同时,KL 散度等于 0。在 VAE 中,我们希望瓶颈特征值遵循平均值为 0 和标准差为 1 的正态分布。因此,我们可以使用 KL 散度衡量变分自编码器中编码器输出的分布与标准高斯分布 N ( 0 , 1 ) N(0,1) N(0,1) 之间的差异。
可以通过以下公式计算 KL 散度损失:

∑ i = 1 n σ i 2 + μ i 2 − l o g ∗ ( σ i ) − 1 \sum_{i=1}^n\sigma_i^2+\mu_i^2-log*(\sigma_i)-1 i=1nσi2+μi2log(σi)1

在上式中, σ σ σ μ μ μ 表示每个输入图像的均值和标准差值:

  • 确保均值向量分布在 0 附近:
    • 最小化上式中的均方误差 ( μ i 2 \mu_i^2 μi2) 可确保 μ \mu μ 尽可能接近 0
  • 确保标准差向量分布在 1 附近:
    • 上式中其余部分(除了 μ i 2 \mu_i^2 μi2 )用于确保标准差 ( s i g m a sigma sigma) 分布在 1 附近

当均值 ( μ μ μ) 为 0 且标准差为 1 时,以上损失函数值达到最小,通过引入标准差的对数,确保 σ \sigma σ 值不为负。通过最小化以上损失可以确保编码器输出遵循预定义分布。

2.4 重参数化技巧

下图左侧显示了 VAE 网络。编码器获取输入 x x x,并估计潜矢量 z z z 的多元高斯分布的均值 μ μ μ 和标准差 σ σ σ,解码器从潜矢量 z z z 采样,以将输入重构为 x x x
VAE
但是反向传播梯度不会通过随机采样块。虽然可以为神经网络提供随机输入,但梯度不可能穿过随机层。解决此问题的方法是将“采样”过程作为输入,如图右侧所示。 采样计算为:

S a m p l e = μ + ε σ Sample=\mu + εσ Sample=μ+εσ

如果 ε ε ε σ σ σ 以矢量形式表示,则 ε σ εσ εσ 是逐元素乘法,使用上式,令采样好像直接来自于潜空间。 这种技术被称为重参数化技巧 (Reparameterization trick)。

3. 实现 VAE

在本节中,使用 PyTorch 实现 VAE 模型生成 cifar-10 图像。

3.1 数据加载

(1) 首先导入所需的库:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

(2) 定义数据预处理转换:

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 将像素值归一化到[-1,1]
])

(3) 加载 CIFAR-10 训练集和测试集:

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

(4) 创建数据加载器:

batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

3.2 模型构建

(1) 定义 VAE 模型,由编码器和解码器构成:

class VAE(nn.Module):def __init__(self, latent_dim=128):super(VAE, self).__init__()self.latent_dim = latent_dim# 编码器self.encoder = nn.Sequential(nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),  # 32x16x16nn.ReLU(),nn.BatchNorm2d(32),nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 64x8x8nn.ReLU(),nn.BatchNorm2d(64),nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 128x4x4nn.ReLU(),nn.BatchNorm2d(128),nn.Flatten(),  # 128*4*4=2048nn.Linear(2048, 1024),nn.ReLU())# 潜在空间的均值和对数方差self.fc_mu = nn.Linear(1024, latent_dim)self.fc_logvar = nn.Linear(1024, latent_dim)# 解码器self.decoder_input = nn.Linear(latent_dim, 1024)self.decoder = nn.Sequential(nn.Linear(1024, 2048),nn.ReLU(),nn.Unflatten(1, (128, 4, 4)),  # 128x4x4nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 64x8x8nn.ReLU(),nn.BatchNorm2d(64),nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 32x16x16nn.ReLU(),nn.BatchNorm2d(32),nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),  # 3x32x32nn.Tanh()  # 输出在[-1,1]之间,与输入归一化一致)def encode(self, x):"""编码输入图像x,返回潜在空间的均值和方差"""h = self.encoder(x)mu = self.fc_mu(h)logvar = self.fc_logvar(h)return mu, logvardef reparameterize(self, mu, logvar):"""重参数化技巧,从N(mu, var)采样"""std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):"""从潜在变量z解码重构图像"""h = self.decoder_input(z)x_recon = self.decoder(h)return x_recondef forward(self, x):mu, logvar = self.encode(x)z = self.reparameterize(mu, logvar)x_recon = self.decode(z)return x_recon, mu, logvar

(2) 定义损失函数,由重建损失和 KL 散度组成:

def vae_loss(recon_x, x, mu, logvar):"""VAE损失函数 = 重构损失 + KL散度"""# 重构损失recon_loss = F.mse_loss(recon_x, x, reduction='sum')# KL散度:-0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())return recon_loss + kl_loss

3.3 模型训练

(1) 定义模型训练和测试函数:

def train(model, device, train_loader, optimizer, epoch):model.train()train_loss = 0for batch_idx, (data, _) in enumerate(train_loader):data = data.to(device)optimizer.zero_grad()# 前向传播recon_batch, mu, logvar = model(data)# 计算损失loss = vae_loss(recon_batch, data, mu, logvar)# 反向传播和优化loss.backward()train_loss += loss.item()optimizer.step()if batch_idx % 100 == 0:print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} 'f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item() / len(data):.6f}')avg_loss = train_loss / len(train_loader.dataset)print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}')return avg_lossdef test(model, device, test_loader):model.eval()test_loss = 0with torch.no_grad():for data, _ in test_loader:data = data.to(device)recon_batch, mu, logvar = model(data)test_loss += vae_loss(recon_batch, data, mu, logvar).item()test_loss /= len(test_loader.dataset)print(f'====> Test set loss: {test_loss:.4f}')return test_loss

(2) 定义可视化函数,用于可视化原始图像和重构图像:

def visualize_reconstruction(model, device, test_loader, num_images=8):model.eval()with torch.no_grad():# 获取一批测试图像data, _ = next(iter(test_loader))data = data[:num_images].to(device)# 重构图像recon_data, _, _ = model(data)# 将图像从[-1,1]转换回[0,1]以便显示data = data.cpu().numpy().transpose(0, 2, 3, 1)data = (data + 1) / 2  # 从[-1,1]到[0,1]recon_data = recon_data.cpu().numpy().transpose(0, 2, 3, 1)recon_data = (recon_data + 1) / 2  # 从[-1,1]到[0,1]# 绘制图像fig, axes = plt.subplots(2, num_images, figsize=(num_images * 2, 4))for i in range(num_images):axes[0, i].imshow(data[i])axes[0, i].axis('off')axes[1, i].imshow(recon_data[i])axes[1, i].axis('off')axes[0, 0].set_ylabel('Original')axes[1, 0].set_ylabel('Reconstructed')plt.show()

(3) 定义 generate_samples(),从潜空间随机采样生成新图像:

def generate_samples(model, device, latent_dim, num_samples=16):model.eval()with torch.no_grad():# 从标准正态分布采样z = torch.randn(num_samples, latent_dim).to(device)# 生成样本samples = model.decode(z).cpu()samples = samples.numpy().transpose(0, 2, 3, 1)samples = (samples + 1) / 2  # 从[-1,1]到[0,1]# 绘制生成的样本fig, axes = plt.subplots(4, 4, figsize=(8, 8))for i, ax in enumerate(axes.flat):ax.imshow(samples[i])ax.axis('off')plt.show()

(4) 训练模型 50epoch,训练完成后,可视化模型生成效果,并绘制训练和测试损失变化曲线:

def main():# 设置设备device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"Using device: {device}")# 初始化模型latent_dim = 128model = VAE(latent_dim=latent_dim).to(device)# 定义优化器optimizer = optim.Adam(model.parameters(), lr=1e-4)# 训练参数epochs = 50train_losses = []test_losses = []# 训练循环for epoch in range(1, epochs + 1):train_loss = train(model, device, train_loader, optimizer, epoch)test_loss = test(model, device, test_loader)train_losses.append(train_loss)test_losses.append(test_loss)# 每5个epoch可视化一次if epoch % 5 == 0:visualize_reconstruction(model, device, test_loader)# 训练完成后可视化generate_samples(model, device, latent_dim)# 绘制训练和测试损失曲线plt.figure(figsize=(10, 5))plt.plot(train_losses, label='Train Loss')plt.plot(test_losses, label='Test Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.title('Training and Test Loss')plt.show()main()

重建效果:

重建效果
生成结果:

生成结果

http://www.xdnf.cn/news/1074799.html

相关文章:

  • 基于STM32温湿度检测—串口显示
  • HTML5 实现的圣诞主题网站源码,使用了 HTML5 和 CSS3 技术,界面美观、节日氛围浓厚。
  • k8s pod深度解析
  • k8s创建定时的 Python 任务(CronJob)
  • 【c/c++1】数据类型/指针/结构体,static/extern/makefile/文件
  • 机器学习9——决策树
  • 新生代潜力股刘小北:演艺路上的璀璨新星
  • ROS常用的路径规划算法介绍
  • 面试复盘6.0
  • Java面试宝典:基础四
  • SpringSecurity6-oauth2-三方gitee授权-授权码模式
  • 详解快速排序
  • 宏任务与微任务和Dom渲染的关系
  • 左神算法之螺旋打印
  • Redis Cluster Gossip 协议
  • 在Linux系统中部署Java项目
  • 设计模式之装饰者模式
  • 2.安装Docker
  • 怎样学习STM32
  • 暴力风扇方案介绍
  • HarmonyOS实战:自定义表情键盘
  • FPGA实现CameraLink视频解码,基于Xilinx ISERDES2原语,提供4套工程源码和技术支持
  • llama.cpp学习笔记:后端加载
  • 图书管理系统练习项目源码-前后端分离-使用node.js来做后端开发
  • Conda 环境配置之 -- Mamba安装(causal-conv1d、mamba_ssm 最简单配置方法)-- 不需要重新配置CDUA
  • 领域驱动设计(DDD)【26】之CQRS模式初探
  • AlpineLinux安装部署elasticsearch
  • Kafka4.0初体验
  • Python爬虫:Requests与Beautiful Soup库详解
  • 重写(Override)与重载(Overload)深度解析