当前位置: 首页 > news >正文

PyTorch生成式人工智能——深度分层变分自编码器(NVAE)详解与实现

PyTorch生成式人工智能——深度分层变分自编码器(NVAE)详解与实现

    • 0. 前言
    • 1. NVAE 技术原理
      • 1.1 变分自编码器基础
      • 1.2 深度分层架构
      • 1.3 多尺度架构设计
    • 2. 残差单元与可分离卷积
    • 3. 残差参数化与后验分布
    • 4. 使用 PyTorch 构建 NVAE
      • 4.1 数据集加载
      • 4.2 模型构建与训练
    • 相关链接

0. 前言

变分自编码器 (Variational Autoencoder, VAE) 作为深度学习生成模型的重要分支,具有独特的优势,与生成对抗网络 (Generative Adversarial Network, GAN) 和自回归模型相比,VAE 具有采样速度快、计算可处理性强以及编码网络易于访问等优势。然而,传统的 VAE 模型在生成质量上往往落后于其他先进生成模型,尤其是在处理高分辨率自然图像时表现不佳。为了应对这一挑战,深度分层变分自编码器 (Nouveau VAE, NVAE) 通过神经架构设计的创新,推动了 VAE 性能的提升。

1. NVAE 技术原理

1.1 变分自编码器基础

传统变分自编码器 (Variational Autoencoder, VAE) 由编码器和解码器组成。编码器将输入数据 xxx 映射到潜空间的后验分布 q(z∣x)q(z|x)q(zx),解码器从潜在变量 zzz 重建数据 xxxVA E的训练目标是最大化证据下界 (Evidence Lower Bound, ELBO):
log⁡p(x)≥Eqϕ(z∣x)[log⁡pθ(x∣z)]−DKL(qϕ(z∣x)∣∣p(z))\text {log}⁡p(x)≥\mathbb E_{q_ϕ(z|x)}[\text {log}⁡p_θ(x|z)]−D_{KL}(q_ϕ(z|x)||p(z)) logp(x)Eqϕ(zx)[logpθ(xz)]DKL(qϕ(zx)∣∣p(z))
其中第一项是重建损失,第二项是潜空间分布与先验分布的 KL 散度

1.2 深度分层架构

NVAE 采用深度分层架构,将潜变量分为 LLL 组:z=z1,z2,...,zLz = {z_1, z_2, ..., z_L}z=z1,z2,...,zL,其中 z1z_1z1 是最底层(最抽象)的变量,zLz_LzL 是最高层(最接近输入)的变量,形成了层次化的潜表示。这种设计使得先验和后验分布都变成了联合分布,能够在不同层次上捕获数据的抽象特征:
pθ(x,z1:L)=pθ(x∣z1:L)∏i=1Lpθ(zi∣zi+1:L)qφ(z1:L∣x)=∏i=1Lqφ(zi∣zi+1:L,x)p_θ(x, z_{1:L}) = p_θ(x|z_{1:L})∏_{i=1}^L p_θ(z_i|z_{i+1:L})\\ q_φ(z_{1:L}|x) = ∏_{i=1}^L q_φ(z_i|z_{i+1:L}, x) pθ(x,z1:L)=pθ(xz1:L)i=1Lpθ(zizi+1:L)qφ(z1:Lx)=i=1Lqφ(zizi+1:L,x)
这种分层设计允许模型在多个分辨率级别上处理输入数据,较低层次的组捕获细节信息,而较高层次的组捕获语义级别的抽象信息。这与人类视觉系统的层次化处理方式相似,使得模型能够生成全局一致且细节丰富的高分辨率图像。

1.3 多尺度架构设计

NVAE 采用了多尺度架构,在处理图像时在不同层次使用不同的分辨率。编码器逐步降低输入图像的分辨率,同时增加通道数;而解码器则执行相反的过程,逐步上采样并减少通道数。这种设计使得计算更加高效,同时保持了模型对细节和全局结构的表现能力。

2. 残差单元与可分离卷积

NVAE 的基础构建模块是专门设计的残差单元 (Residual Cell),这些单元格在编码器和解码器中都有使用。每个残差单元包含批量归一化 (Batch Normalization, BN)、Swish 激活函数和深度可分离卷积 (Depth-wise Separable Convolutions) 等组件。这种设计不仅保证了数值稳定性,还显著减少了参数量。NVAE 中所用残差单元如下图所示。

残差单元

深度可分离卷积是 NVAE 中的关键技术创新之一。与标准卷积相比,深度可分离卷积将卷积操作分解为两个步骤:深度卷积(对每个输入通道单独进行空间卷积)和逐点卷积( 1×1 卷积,用于组合通道信息)。这种分解大幅减少了计算复杂度和参数数量,使模型能够快速扩大感受野而不受计算资源的限制。
NVAE 中还采用了挤压和激励 (Squeeze-and-Excitation, SE) 模块来增强模型的表示能力。SE 模块通过自适应地重新校准通道特征响应,使模型能够关注最信息丰富的特征。这一机制与残差单元格的结合进一步提升了模型对重要特征的敏感性。

3. 残差参数化与后验分布

NVAE 提出了残差参数化 (residual parameterization) 方法来改进近似后验分布的表达能力。在传统 VAE 中,近似后验分布通常被假设为对角协方差高斯分布,这种假设限制了模型的表达能力。NVAE 通过残差参数化放松了这一限制,允许更灵活的后验分布形式。
具体而言,对于每个层次的潜在变量,其均值和方差不是直接从网络输出计算得到,而是基于先前层次的残差更新来计算。这种方法确保了潜在变量之间的依赖性,同时保持了计算的可处理性。数学上,这种参数化方式使得 KL 散度项可以解析计算,避免了复杂的近似方法。
残差参数化还与条件先验 (conditional prior) 的概念紧密结合。在生成过程中,每个层次的先验分布不仅依赖于先前潜在变量,还依赖于自上而下的网络传递的信息。这种设计使先验分布更加丰富和表达力强,有助于生成更高质量的样本。

4. 使用 PyTorch 构建 NVAE

接下来,我们将使用 Celeb A 人脸图像数据集构建 NVAE

4.1 数据集加载

(1) 首先,导入所需库:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
import torchvision.transforms as transforms
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader, Dataset
import torchvision.utils as vutils
import numpy as np
from PIL import Image
from glob import glob
import torchvision

(2) 定义图像预处理变换:

image_size = 64
transform = transforms.Compose([transforms.Resize(image_size),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 将图像归一化到[-1, 1]范围

(3) 创建数据集和数据加载器:

batch_size = 32
ds = Faces(folder='cropped_faces/*.jpg')
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=8)

4.2 模型构建与训练

(1) 实现残差单元,用于构建 NVAE的编码器和解码器,使用深度可分离卷积提高效率:

class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super(ResidualBlock, self).__init__()# 第一个卷积层self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)# 第二个卷积层self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)# 快捷连接self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels))# 应用谱归一化以稳定训练self.conv1 = spectral_norm(self.conv1)self.conv2 = spectral_norm(self.conv2)if len(self.shortcut) > 0:self.shortcut[0] = spectral_norm(self.shortcut[0])def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += self.shortcut(x)out = F.relu(out)return out

(2) 定义 NVAE 编码器块,包含多个残差单元和下采样操作,每个编码器块处理特定分辨率的特征:

class EncoderBlock(nn.Module):def __init__(self, in_channels, out_channels, num_blocks, stride=2):super(EncoderBlock, self).__init__()self.blocks = nn.ModuleList()# 第一个块进行下采样self.blocks.append(ResidualBlock(in_channels, out_channels, stride=stride))# 添加额外的残差块(不下采样)for _ in range(1, num_blocks):self.blocks.append(ResidualBlock(out_channels, out_channels, stride=1))def forward(self, x):for block in self.blocks:x = block(x)return x

(3) 定义 NVAE 解码器块,包含多个残差块和上采样操作,每个解码器块重建特定分辨率的特征:

class DecoderBlock(nn.Module):def __init__(self, in_channels, out_channels, num_blocks, stride=2):super(DecoderBlock, self).__init__()self.blocks = nn.ModuleList()# 添加残差块for _ in range(num_blocks - 1):self.blocks.append(ResidualBlock(in_channels, in_channels, stride=1))# 最后一个块进行上采样if stride > 1:self.upsample = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU())else:self.upsample = nn.Identity()def forward(self, x):for block in self.blocks:x = block(x)x = self.upsample(x)return x

(4) 构建完整的 NVAE 模型,包含分层编码器和解码器,采用多尺度潜空间结构:

class NVAE(nn.Module):def __init__(self, image_channels=3, latent_dim=128, num_layers=4):super(NVAE, self).__init__()self.latent_dim = latent_dimself.num_layers = num_layers# 初始卷积层self.initial_conv = nn.Sequential(nn.Conv2d(image_channels, 32, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(32),nn.ReLU())# 编码器层级self.enc_blocks = nn.ModuleList()self.enc_blocks.append(EncoderBlock(32, 64, num_blocks=2, stride=2))  # 64x64 -> 32x32self.enc_blocks.append(EncoderBlock(64, 128, num_blocks=2, stride=2)) # 32x32 -> 16x16self.enc_blocks.append(EncoderBlock(128, 256, num_blocks=2, stride=2)) # 16x16 -> 8x8self.enc_blocks.append(EncoderBlock(256, 512, num_blocks=2, stride=2)) # 8x8 -> 4x4# 潜在空间均值和对数方差预测self.fc_mu = nn.Linear(512 * 4 * 4, latent_dim)self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)# 解码器层级self.dec_blocks = nn.ModuleList()self.dec_blocks.append(nn.Sequential(nn.Linear(latent_dim, 512 * 4 * 4),nn.ReLU()))self.dec_blocks.append(DecoderBlock(512, 256, num_blocks=2, stride=2))  # 4x4 -> 8x8self.dec_blocks.append(DecoderBlock(256, 128, num_blocks=2, stride=2))  # 8x8 -> 16x16self.dec_blocks.append(DecoderBlock(128, 64, num_blocks=2, stride=2))   # 16x16 -> 32x32self.dec_blocks.append(DecoderBlock(64, 32, num_blocks=2, stride=2))    # 32x32 -> 64x64# 最终重建层self.final_conv = nn.Sequential(nn.Conv2d(32, image_channels, kernel_size=3, stride=1, padding=1),nn.Tanh()  # 输出范围[-1, 1],与输入一致)def encode(self, x):# 编码输入图像,返回潜分布的均值和方差x = self.initial_conv(x)for block in self.enc_blocks:x = block(x)x = x.view(x.size(0), -1)  # 展平mu = self.fc_mu(x)logvar = self.fc_logvar(x)return mu, logvardef reparameterize(self, mu, logvar):# 重参数化技巧,从潜分布中采样std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):# 从潜变量解码重建图像x = self.dec_blocks[0](z)x = x.view(-1, 512, 4, 4)  # 重塑为特征图for i in range(1, len(self.dec_blocks)):x = self.dec_blocks[i](x)x = self.final_conv(x)return xdef forward(self, x):# 前向传播:编码输入图像,采样潜变量,然后解码重建mu, logvar = self.encode(x)z = self.reparameterize(mu, logvar)recon_x = self.decode(z)return recon_x, mu, logvardef sample(self, num_samples, device):# 从潜空间采样生成新图像z = torch.randn(num_samples, self.latent_dim).to(device)samples = self.decode(z)return samples

(5) 定义损失函数,结合重建损失和 KL 散度

def loss_function(recon_x, x, mu, logvar, beta=1.0):# 重建损失(使用均方误差)recon_loss = F.mse_loss(recon_x, x, reduction='sum')# KL散度损失kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())# 总损失total_loss = recon_loss + beta * kld_lossreturn total_loss, recon_loss, kld_loss

(6) 定义设备,并实例化模型和优化器:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")model = NVAE(image_channels=3, latent_dim=256, num_layers=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

(7) 定义模型训练和验证函数:

def train(model, dataloader, optimizer, epoch, device, beta=1.0):model.train()train_loss = 0recon_loss = 0kld_loss = 0for batch_idx, (data, _) in enumerate(dataloader):data = data.to(device)optimizer.zero_grad()# 前向传播recon_batch, mu, logvar = model(data)# 计算损失loss, r_loss, k_loss = loss_function(recon_batch, data, mu, logvar, beta)# 反向传播loss.backward()# 梯度裁剪(防止梯度爆炸)torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)optimizer.step()train_loss += loss.item()recon_loss += r_loss.item()kld_loss += k_loss.item()if batch_idx % 100 == 0:print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(dataloader.dataset)} 'f'({100. * batch_idx / len(dataloader):.0f}%)]\tLoss: {loss.item() / len(data):.6f}')# 计算平均损失avg_loss = train_loss / len(dataloader.dataset)avg_recon = recon_loss / len(dataloader.dataset)avg_kld = kld_loss / len(dataloader.dataset)print(f'====> Epoch: {epoch} 平均损失: {avg_loss:.4f} 'f'平均重建损失: {avg_recon:.4f} 平均KL损失: {avg_kld:.4f}')return avg_loss, avg_recon, avg_klddef validate(model, dataloader, device, beta=1.0):model.eval()val_loss = 0recon_loss = 0kld_loss = 0with torch.no_grad():for i, (data, _) in enumerate(dataloader):data = data.to(device)recon_batch, mu, logvar = model(data)loss, r_loss, k_loss = loss_function(recon_batch, data, mu, logvar, beta)val_loss += loss.item()recon_loss += r_loss.item()kld_loss += k_loss.item()avg_loss = val_loss / len(dataloader.dataset)avg_recon = recon_loss / len(dataloader.dataset)avg_kld = kld_loss / len(dataloader.dataset)print(f'====> 验证集损失: {avg_loss:.4f} 'f'重建损失: {avg_recon:.4f} KL损失: {avg_kld:.4f}')return avg_loss, avg_recon, avg_kld

(8) 训练模型:

# 开始训练
num_epochs = 500
beta = 0.1  # KL散度的权重系数train_losses = []
val_losses = []for epoch in range(1, num_epochs + 1):train_loss, train_recon, train_kld = train(model, dataloader, optimizer, epoch, device, beta)val_loss, val_recon, val_kld = validate(model, dataloader, device, beta)train_losses.append(train_loss)val_losses.append(val_loss)# 更新学习率scheduler.step()# 每10个epoch保存一次模型和样本if epoch % 2 == 0:torch.save(model.state_dict(), f'nvae_celeba_epoch_{epoch}.pth')# 生成并保存样本with torch.no_grad():sample = torch.randn(16, 256).to(device)sample = model.decode(sample).cpu()vutils.save_image(sample, f'sample_epoch_{epoch}.png', nrow=4, normalize=True)# 保存重建示例test_iter = iter(dataloader)test_data, _ = next(test_iter)test_data = test_data.to(device)with torch.no_grad():recon_data, _, _ = model(test_data)comparison = torch.cat([test_data[:8], recon_data[:8]]).cpu()vutils.save_image(comparison, f'reconstruction_epoch_{epoch}.png', nrow=8, normalize=True)# 保存最终模型
torch.save(model.state_dict(), 'nvae_celeba_final.pth')

模型训练过程,模型重建效果如下所示,可以看到重建效果随着训练的进行不断得到改进:

模型训练过程

接下来,查看模型训练完成后的生成图像:

生成图像

相关链接

PyTorch生成式人工智能实战:从零打造创意引擎
PyTorch生成式人工智能(1)——神经网络与模型训练过程详解
PyTorch生成式人工智能(2)——PyTorch基础
PyTorch生成式人工智能(3)——使用PyTorch构建神经网络
PyTorch生成式人工智能(4)——卷积神经网络详解
PyTorch生成式人工智能(5)——分类任务详解
PyTorch生成式人工智能(6)——生成模型(Generative Model)详解
PyTorch生成式人工智能(7)——生成对抗网络实践详解
PyTorch生成式人工智能(8)——深度卷积生成对抗网络
PyTorch生成式人工智能(9)——Pix2Pix详解与实现
PyTorch生成式人工智能(10)——CyclelGAN详解与实现
PyTorch生成式人工智能(12)——StyleGAN详解与实现
PyTorch生成式人工智能(13)——WGAN详解与实现
PyTorch生成式人工智能(14)——条件生成对抗网络(conditional GAN,cGAN)
PyTorch生成式人工智能(15)——自注意力生成对抗网络(Self-Attention GAN, SAGAN)
PyTorch生成式人工智能(16)——自编码器(AutoEncoder)详解
PyTorch生成式人工智能(17)——变分自编码器详解与实现
PyTorch生成式人工智能(18)——循环神经网络详解与实现
PyTorch生成式人工智能(19)——自回归模型详解与实现
PyTorch生成式人工智能(20)——像素卷积神经网络(PixelCNN)
PyTorch生成式人工智能(24)——使用PyTorch构建Transformer模型
PyTorch生成式人工智能(25)——基于Transformer实现机器翻译
PyTorch生成式人工智能——VQ-VAE详解与实现

http://www.xdnf.cn/news/1474417.html

相关文章:

  • 贪心算法应用:基因编辑靶点选择问题详解
  • 【C++】类和对象(三)
  • Git reset 回退版本
  • stunnel实现TCP双向认证加密
  • Custom SRP - Complex Maps
  • 顺丰,途虎养车,优博讯,得物,作业帮,途游游戏,三七互娱,汤臣倍健,游卡,快手26届秋招内推
  • JVM如何排查OOM
  • 01.单例模式基类模块
  • 微信小程序携带token跳转h5, h5再返回微信小程序
  • Knative Serving:ABP 应用的 scale-to-zero 与并发模型
  • 【Python 】入门:安装教程+入门语法
  • 使用 C# .NETCore 实现MongoDB
  • OpenAI新论文:Why Language Models Hallucinate
  • 【黑客技术零基础入门】2W字零基础小白黑客学习路线,知识体系(附学习路线图)
  • 【C++】C++11的可变参数模板、emplace接口、类的新功能
  • 《云原生微服务治理进阶:隐性风险根除与全链路能力构建》
  • 旧电脑改造服务器1:启动盘制作
  • Element-Plus
  • Nestjs框架: 基于权限的精细化权限控制方案与 CASL 在 Node.js 中的应用实践
  • 【Mysql-installer-community-8.0.26.0】Mysql 社区版(8.0.26.0) 在Window 系统的默认安装配置
  • Nikto 漏洞扫描工具使用指南
  • 管家婆辉煌系列软件多仓库出库操作指南
  • Kubernetes (k8s)
  • MySQL连接字符串中的安全与性能参数详解
  • Monorepo 是什么?如何使用并写自己的第三方库
  • 聊聊OAuth2.0和OIDC
  • 音转文模型对比FunASR与Faster_whisper
  • 《sklearn机器学习——聚类性能指标》Contingency Matrix(列联表)详解
  • PlantSimulation 在汽车总装车间配送物流仿真中的应用
  • Fantasia3D:高质量文本到3D内容创建工具