当前位置: 首页 > news >正文

PyTorch生成式人工智能——VQ-VAE详解与实现

PyTorch生成式人工智能——VQ-VAE详解与实现

    • 0. 前言
    • 1. VQ-VAE 技术原理
      • 1.1 引入离散潜变量
      • 1.2 向量量化
      • 1.3 损失函数
      • 1.4 指数滑动平均
      • 1.5 梯度直通 (Straight-Through)
    • 2. VQ-VAE 网络架构
    • 3. 实现 VQ-VAE
      • 3.1 模型构建
      • 3.2 模型训练
    • 相关链接

0. 前言

在传统的变分自编码器 (Variational Auto-Encoder, VAE) 中,模型学习的是一个连续的潜在表示 (latent representation)。然而,对于许多模态的数据,如语言、语音或某些图像特征,离散的潜表示往往更加自然有效。
VQ-VAE (Vector Quantised-Variational AutoEncoder) 的核心思想就是将 VAE 的连续潜变量离散化。它通过学习一个码本 (Codebook) 来实现这一点,码本是一个包含有限个嵌入向量的字典。模型不是直接输出一个连续的潜在向量,而是从码本中找出与编码器输出最接近的嵌入向量来代替它。这种离散化带来了以下优势:

  • 兼容性:离散的潜在空间可以很自然地与自回归模型(如 PixelCNN、Transformer )结合,用于强大的先验建模,从而生成高质量的新样本
  • 计算效率:对于下游任务,处理离散的 token 通常比处理连续的向量更高效
  • 可解释性:码本中的每个向量可以看作是学习到的一种“视觉单词”或基本概念

本节首先详细讲解 VQ-VAE 的技术原理,然后使用 PyTorch 从零开始实现 VQ-VAE 模型。

1. VQ-VAE 技术原理

1.1 引入离散潜变量

传统的变分自编码器 (Variational Auto-Encoder, VAE) 的潜变量 zzz 连续且服从高斯先验,重建质量与生成保真度在某些任务上(如语音、图像纹理)不够理想。VQ-VAE 通过码本 (Codebook) 引入离散潜变量:编码器将输入映射到隐空间连续向量 ze(x)z_e(x)ze(x),随后用最近邻查找从码本 {ek}k=1K\{e_k\}_{k=1}^K{ek}k=1K 选出索引 kkk,离散化为 zq(x)=ekz_q(x)=e^kzq(x)=ek,再由解码器重建。优势在于:

  • 信息瓶颈自然离散化:离散索引更像“符号化”的语义单元
  • 便于后续建模:可以对索引序列用自回归(如 PixelCNN、Transformer )建立先验模型

1.2 向量量化

对每个位置的连续表示 zez_eze,选择使欧氏距离最小的码本向量:
k=arg⁡mink∣∣ze−ek∣∣2,zq=ekk=\underset k{arg⁡min}||z_e−e_k||_2,z_q=e_k k=kargmin∣∣zeek2,zq=ek

1.3 损失函数

VQ-VAE 的核心损失包含三部分:
L=∣∣x−x^∣∣1⏟reconstruction+∣∣sg[ze]−e∣∣22⏟codebook+β∣∣ze−sg[e]∣∣22⏟commitment\mathcal L=\underbrace {||x−\hat x||_1}_{reconstruction} +  \underbrace {||sg[z_e]−e||_2^2}_{codebook}  +  \underbrace {β||z_e−sg[e]||_2^2}_{commitment} L=reconstruction∣∣xx^1+  codebook∣∣sg[ze]e22  +  commitmentβ∣∣zesg[e]22
其中:

  • sg[⋅]sg[\cdot]sg[] (stop gradient) 阻止梯度回传,即在反向传播时将其视为常数
  • 第一部分为重构损失 (Reconstruction Loss),用于最小化输入与重构的差异
  • 第二部分为码本损失 (Codebook Loss),让码本向量 eee 向编码器输出 zez_eze 靠近,通常使用 L2 损失,stop gradient 操作作用在编码器上,因此这个损失只更新码本,不更新编码器
  • 第三部分为 Commitment Loss,让编码器的输出 zez_eze 向选中的码本向量 eee 靠近,防止编码器的输出在码本空间内随意波动,通常使用 L2 损失,stop gradient 作用在码本向量上,因此这个损失只更新编码器,不更新码本
  • βββ 是一个超参数,通常取 0.25–0.5

1.4 指数滑动平均

指数滑动平均 (Exponential Moving Average, EMA) 使用聚类视角更新码本参数,可以避免显式的码本惩罚 (codebook-penalty):
Nk←γNk+(1−γ)⋅countkmk←γmk+(1−γ)⋅∑i∈kze,iek←mkNk+ϵN_k\leftarrow\gamma N_k+(1-\gamma)\cdot count_k\\ m_k\leftarrow\gamma m_k+(1-\gamma)\cdot \sum_{i\in k}z_{e,i}\\ e_k\leftarrow \frac{m_k}{N_k+\epsilon} NkγNk+(1γ)countkmkγmk+(1γ)ikze,iekNk+ϵmk
只保留 Reconstruction LossCommitment Loss,收敛更稳,码本利用率更好。

1.5 梯度直通 (Straight-Through)

量化过程 (argmin) 是不可导的,这阻碍了梯度从解码器传回编码器。VQ-VAE采用了一个巧妙的技巧:在反向传播时,直接将解码器关于 z_q 的梯度 ∂L/∂zq∂\mathcal L/∂z_qL/zq 复制给编码器的输出 zez_eze。即:
∂L/∂ze=∂L/∂zq∂\mathcal L/∂z_e = ∂\mathcal L/∂z_q L/ze=L/zq
这样,虽然量化操作本身没有梯度,但编码器仍然可以接收到来自解码器的梯度信号并进行更新。

2. VQ-VAE 网络架构

VQ-VAE 的整体结构如下图所示,其核心是一个由编码器、码本、解码器组成的架构。
网络架构

网络训练流程如下:

  • 编码器生成 zez_eze
  • 量化层选择最近码本向量 zq=ekz_q = e_kzq=ek
  • 解码器重构 x^\hat{x}x^
  • 计算总损失并反向传播,更新编码器、解码器和码本。

3. 实现 VQ-VAE

了解了 VQ-VAE 的核心原理和训练流程后,接下来,使用 PyTorch 从零开始实现 VQ-VAE 模型。

3.1 模型构建

(1) 首先,导入所需库,并定义命令行参数解析函数:

# vqvae.py
import os, math, random, argparse, time
from pathlib import Pathimport torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from torch import ampdef save_image_grid(tensors, filename, nrow=8):grid = utils.make_grid(torch.clamp(tensors, -0.5, 0.5) + 0.5, nrow=nrow)os.makedirs(os.path.dirname(filename), exist_ok=True)utils.save_image(grid, filename)def parse_args():p = argparse.ArgumentParser()p.add_argument('--data', type=str, default='./data')p.add_argument('--epochs', type=int, default=30)p.add_argument('--batch_size', type=int, default=128)p.add_argument('--lr', type=float, default=2e-4)p.add_argument('--commit', type=float, default=0.25, help='commitment beta')p.add_argument('--ema', action='store_true', help='use EMA codebook')p.add_argument('--codebook_size', type=int, default=512)p.add_argument('--embed_dim', type=int, default=64, help='latent channel dim (before quantize)')p.add_argument('--levels', type=int, default=1, help='number of VQ levels (this demo uses 1)')p.add_argument('--ckpt', type=str, default='./checkpoints/vqvae_best.pt')p.add_argument('--eval_only', action='store_true')p.add_argument('--amp', action='store_true', help='enable AMP mixed precision')return p.parse_args()

(2) 定义残差块、编码器与解码器:

class ResBlock(nn.Module):def __init__(self, c):super().__init__()self.net = nn.Sequential(nn.ReLU(),nn.Conv2d(c, c, 3, padding=1),nn.ReLU(),nn.Conv2d(c, c, 1))def forward(self, x):return x + self.net(x)class Encoder(nn.Module):def __init__(self, in_channels=3, hidden=128, embed_dim=64):super().__init__()self.net = nn.Sequential(nn.Conv2d(in_channels, hidden//2, 4, stride=2, padding=1), # 32->16nn.ReLU(),nn.Conv2d(hidden//2, hidden, 4, stride=2, padding=1),      # 16->8nn.ReLU(),ResBlock(hidden),ResBlock(hidden),nn.ReLU(),nn.Conv2d(hidden, embed_dim, 1)  # project to z_e (C=embed_dim))def forward(self, x):return self.net(x)class Decoder(nn.Module):def __init__(self, out_channels=3, hidden=128, embed_dim=64):super().__init__()self.net = nn.Sequential(nn.Conv2d(embed_dim, hidden, 3, padding=1),ResBlock(hidden),ResBlock(hidden),nn.ReLU(),nn.ConvTranspose2d(hidden, hidden//2, 4, stride=2, padding=1), # 8->16nn.ReLU(),nn.ConvTranspose2d(hidden//2, out_channels, 4, stride=2, padding=1), # 16->32nn.Tanh()  # output in [-1,1])

(3) 实现标准向量量化器,将编码器输出的连续向量映射到码本中最近的离散嵌入向量:

class VectorQuantizer(nn.Module):def __init__(self, codebook_size=512, embed_dim=64, beta=0.25):super().__init__()self.codebook_size = codebook_sizeself.embed_dim = embed_dimself.beta = betaself.embedding = nn.Embedding(codebook_size, embed_dim)nn.init.uniform_(self.embedding.weight, -1.0 / codebook_size, 1.0 / codebook_size)@torch.no_grad()def _nearest_indices(self, z_e):# z_e: (B,C,H,W) -> (BHW,C)z = z_e.permute(0,2,3,1).contiguous().view(-1, self.embed_dim)# distances: |z|^2 + |e|^2 - 2 z e^Te = self.embedding.weight  # (K,C)z2 = (z ** 2).sum(dim=1, keepdim=True)  # (N,1)e2 = (e ** 2).sum(dim=1)                # (K,)# (N,K)distances = z2 + e2.unsqueeze(0) - 2.0 * z @ e.t()indices = distances.argmin(dim=1)       # (N,)return indicesdef forward(self, z_e):B, C, H, W = z_e.shapewith torch.no_grad():indices = self._nearest_indices(z_e)   # (BHW,)# straight-through estimatorz_q = self.embedding(indices).view(B, H, W, C).permute(0,3,1,2).contiguous()# codebook + commitment losses# codebook: ||sg[z_e] - e||^2 -> pull e toward z_eloss_cb = F.mse_loss(z_q.detach(), z_e)# commitment: beta * ||z_e - sg[e]||^2 -> pull z_e toward eloss_commit = F.mse_loss(z_e, z_q.detach())loss_vq = loss_cb + self.beta * loss_commit# straight-through: copy gradientsz_q = z_e + (z_q - z_e).detach()# perplexitywith torch.no_grad():one_hot = F.one_hot(indices, num_classes=self.codebook_size).float()probs = one_hot.mean(dim=0)perplexity = torch.exp(-(probs * (probs + 1e-10).log()).sum())return z_q, loss_vq, perplexity, indices.view(B, H, W)

(4) 定义数滑动平均 (Exponential Moving Average, EMA) 向量量化器,用 EMA 聚类更新码本:

class VectorQuantizerEMA(nn.Module):def __init__(self, codebook_size=512, embed_dim=64, beta=0.25, decay=0.99, eps=1e-5):super().__init__()self.codebook_size = codebook_sizeself.embed_dim = embed_dimself.beta = betaself.decay = decayself.eps = epsembed = torch.randn(codebook_size, embed_dim) * 0.1self.register_buffer('embedding', embed)self.register_buffer('ema_cluster_size', torch.zeros(codebook_size))self.register_buffer('ema_embed', embed.clone())@torch.no_grad()def _nearest_indices(self, z_e):z = z_e.permute(0,2,3,1).contiguous().view(-1, self.embed_dim)  # (N,C)e = self.embedding  # (K,C)z2 = (z ** 2).sum(dim=1, keepdim=True)e2 = (e ** 2).sum(dim=1)distances = z2 + e2.unsqueeze(0) - 2.0 * z @ e.t()return distances.argmin(dim=1)def forward(self, z_e):B, C, H, W = z_e.shapewith torch.no_grad():indices = self._nearest_indices(z_e)  # (BHW,)one_hot = F.one_hot(indices, num_classes=self.codebook_size).float()  # (N,K)cluster_size = one_hot.sum(dim=0)  # (K,)# EMA updatesself.ema_cluster_size.mul_(self.decay).add_(cluster_size, alpha=1 - self.decay)z_sum = (z_e.permute(0,2,3,1).contiguous().view(-1, C).unsqueeze(2) * one_hot.unsqueeze(1)).sum(dim=0)  # (C,K)self.ema_embed.mul_(self.decay).add_(z_sum.t(), alpha=1 - self.decay)  # (K,C)n = self.ema_cluster_size.sum()cluster_size = (self.ema_cluster_size + self.eps) / (n + self.codebook_size * self.eps) * nembed_normalized = self.ema_embed / cluster_size.unsqueeze(1)self.embedding.copy_(embed_normalized)z_q = self.embedding[indices].view(B, H, W, C).permute(0,3,1,2).contiguous()# only commitment termloss_commit = F.mse_loss(z_e, z_q.detach())loss_vq = self.beta * loss_commit# straight-throughz_q = z_e + (z_q - z_e).detach()with torch.no_grad():probs = one_hot.mean(dim=0)perplexity = torch.exp(-(probs * (probs + 1e-10).log()).sum())return z_q, loss_vq, perplexity, indices.view(B, H, W)

(5) 将编码器、量化器与解码器组合为端到端模型:

class VQVAE(nn.Module):def __init__(self, in_channels=3, hidden=128, embed_dim=64, codebook_size=512, beta=0.25, use_ema=False):super().__init__()self.encoder = Encoder(in_channels, hidden, embed_dim)if use_ema:self.quantizer = VectorQuantizerEMA(codebook_size, embed_dim, beta=beta)else:self.quantizer = VectorQuantizer(codebook_size, embed_dim, beta=beta)self.decoder = Decoder(in_channels, hidden, embed_dim)def forward(self, x):z_e = self.encoder(x)                          # (B, C=embed_dim, H=8, W=8)z_q, loss_vq, perplexity, indices = self.quantizer(z_e)x_hat = self.decoder(z_q)return x_hat, loss_vq, perplexity, indices

3.2 模型训练

接下来,使用 CIFAR-10 数据集训练模型。
(1) 定义数据集加载函数:

def get_dataloaders(data_root, batch_size):tfm = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(0.5, 0.5)  # for all 3 channels: (x-0.5)/0.5 -> [-1,1]])tfm_val = transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])train_set = datasets.CIFAR10(data_root, train=True, download=True, transform=tfm)val_set = datasets.CIFAR10(data_root, train=False, download=True, transform=tfm_val)train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)return train_loader, val_loader

(2) 定义训练与验证循环,使用 L1 重建损失,图像更锐利,我们也可以使用 L2 损失观察图像重建效果:

def train_one_epoch(model, loader, opt, scaler, device, use_amp=True):model.train()rec_loss_meter, vq_loss_meter, ppl_meter = 0.0, 0.0, 0.0n = 0for x, _ in loader:x = x.to(device, non_blocking=True)opt.zero_grad(set_to_none=True)with amp.autocast("cuda", enabled=use_amp):x_hat, vq_loss, perplexity, _ = model(x)rec_loss = F.l1_loss(x_hat, x)loss = rec_loss + vq_lossscaler.scale(loss).backward()scaler.step(opt)scaler.update()bs = x.size(0)rec_loss_meter += rec_loss.item() * bsvq_loss_meter += vq_loss.item() * bsppl_meter += perplexity.item() * bsn += bsreturn rec_loss_meter/n, vq_loss_meter/n, ppl_meter/n@torch.no_grad()
def evaluate(model, loader, device, save_samples_path=None, max_batches=1):model.eval()rec_loss_meter, vq_loss_meter, ppl_meter = 0.0, 0.0, 0.0n = 0saved = Falsefor i, (x, _) in enumerate(loader):x = x.to(device, non_blocking=True)x_hat, vq_loss, perplexity, _ = model(x)rec_loss = F.l1_loss(x_hat, x)bs = x.size(0)rec_loss_meter += rec_loss.item() * bsvq_loss_meter += vq_loss.item() * bsppl_meter += perplexity.item() * bsn += bsif (save_samples_path is not None) and (not saved):# 拼接输入/重建cat = torch.cat([x[:32], x_hat[:32]], dim=0).detach().cpu()save_image_grid(cat, save_samples_path, nrow=8)saved = Trueif i+1 >= max_batches:breakreturn rec_loss_meter/n, vq_loss_meter/n, ppl_meter/n

(3) 组织训练/评估流程、保存最优权重与重建结果:

def main():args = parse_args()device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(f'Using device: {device}')train_loader, val_loader = get_dataloaders(args.data, args.batch_size)model = VQVAE(in_channels=3,hidden=128,embed_dim=args.embed_dim,codebook_size=args.codebook_size,beta=args.commit,use_ema=args.ema).to(device)opt = torch.optim.Adam(model.parameters(), lr=args.lr)scaler = amp.GradScaler("cuda", enabled=args.amp)best_val = float('inf')ckpt_dir = Path(args.ckpt).parentckpt_dir.mkdir(parents=True, exist_ok=True)if args.eval_only and os.path.isfile(args.ckpt):print(f'Loading checkpoint: {args.ckpt}')state = torch.load(args.ckpt, map_location=device)model.load_state_dict(state['model'])rec, vq, ppl = evaluate(model, val_loader, device, save_samples_path='./samples/recon_eval.png')print(f'[Eval] rec={rec:.4f} vq={vq:.4f} ppl={ppl:.2f}')returnfor epoch in range(1, args.epochs+1):t0 = time.time()tr_rec, tr_vq, tr_ppl = train_one_epoch(model, train_loader, opt, scaler, device, use_amp=args.amp)val_rec, val_vq, val_ppl = evaluate(model, val_loader, device, save_samples_path=f'./samples/recon_epoch_{epoch:03d}.png')elapsed = time.time() - t0val_total = val_rec + val_vqprint(f'Epoch {epoch:03d} | {elapsed:.1f}s | 'f'train rec={tr_rec:.4f} vq={tr_vq:.4f} ppl={tr_ppl:.2f} || 'f'val rec={val_rec:.4f} vq={val_vq:.4f} ppl={val_ppl:.2f}')if val_total < best_val:best_val = val_totaltorch.save({'model': model.state_dict(),'args': vars(args)}, args.ckpt)print(f'  -> Saved best to {args.ckpt}')if __name__ == '__main__':main()

(4) 在命令行中使用以下命令运行模型训练过程:

# 训练(非 EMA)
python vqvae.py --epochs 20 --batch_size 128 --commit 0.25# 训练(EMA 码本)
python vqvae.py --epochs 20 --batch_size 128 --commit 0.25 --ema# 从已训练权重做重建可视化
python vqvae.py --eval_only --ckpt ./checkpoints/vqvae_best.pt

模型训练过程保存的重建图像如下所示,可以看到随着训练的进行,模型的重建效果逐步得到提升:

模型重建效果

相关链接

PyTorch生成式人工智能实战:从零打造创意引擎
PyTorch生成式人工智能(1)——神经网络与模型训练过程详解
PyTorch生成式人工智能(2)——PyTorch基础
PyTorch生成式人工智能(3)——使用PyTorch构建神经网络
PyTorch生成式人工智能(4)——卷积神经网络详解
PyTorch生成式人工智能(5)——分类任务详解
PyTorch生成式人工智能(6)——生成模型(Generative Model)详解
PyTorch生成式人工智能(7)——生成对抗网络实践详解
PyTorch生成式人工智能(8)——深度卷积生成对抗网络
PyTorch生成式人工智能(9)——Pix2Pix详解与实现
PyTorch生成式人工智能(10)——CyclelGAN详解与实现
PyTorch生成式人工智能(11)——神经风格迁移
PyTorch生成式人工智能(12)——StyleGAN详解与实现
PyTorch生成式人工智能(13)——WGAN详解与实现
PyTorch生成式人工智能(14)——条件生成对抗网络(conditional GAN,cGAN)
PyTorch生成式人工智能(15)——自注意力生成对抗网络(Self-Attention GAN, SAGAN)
PyTorch生成式人工智能(16)——自编码器(AutoEncoder)详解
PyTorch生成式人工智能(17)——变分自编码器详解与实现
PyTorch生成式人工智能(18)——循环神经网络详解与实现
PyTorch生成式人工智能(19)——自回归模型详解与实现
PyTorch生成式人工智能(20)——像素卷积神经网络(PixelCNN)
PyTorch生成式人工智能(24)——使用PyTorch构建Transformer模型
PyTorch生成式人工智能(25)——基于Transformer实现机器翻译
PyTorch生成式人工智能(26)——使用PyTorch构建GPT模型
PyTorch生成式人工智能(27)——从零开始训练GPT模型
PyTorch生成式人工智能(28)——MuseGAN详解与实现

http://www.xdnf.cn/news/1352413.html

相关文章:

  • TypeScript 的泛型(Generics)作用理解
  • Kafka 概念与概述
  • 在TencentOS3上部署OpenTenBase:从入门到实战的完整指南
  • 【Java学习笔记】18.反射与注解的应用
  • 遥感机器学习入门实战教程|Sklearn案例⑧:评估指标(metrics)全解析
  • tcpdump命令打印抓包信息
  • 【golang】ORM框架操作数据库
  • 2-5.Python 编码基础 - 键盘输入
  • STM32CubeIDE V1.9.0下载资源链接
  • 醋酸镨:催化剂领域的璀璨新星
  • LangChain4J-基础(整合Spring、RAG、MCP、向量数据库、提示词、流式输出)
  • 信贷模型域——信贷获客模型(获客模型)
  • 温度对直线导轨的性能有哪些影响?
  • 小白向:Obsidian(Markdown语法学习)快速入门完全指南:从零开始构建你的第二大脑(免费好用的笔记软件的知识管理系统)、黑曜石笔记
  • 数字经济、全球化与5G催生域名新价值的逻辑与实践路径
  • 快速掌握Java非线性数据结构:树(二叉树、平衡二叉树、多路平衡树)、堆、图【算法必备】
  • vue3 - 组件间的传值
  • 【小沐学GIS】基于Godot绘制三维数字地球Earth(Godot)
  • 计算机网络 TLS握手中三个随机数详解
  • 【Golang】有关垃圾收集器的笔记
  • 语义通信高斯信道仿真代码
  • GaussDB 数据库架构师修炼(十八) SQL引擎-计划管理-SQL PATCH
  • Base64编码、AES加密、RSA加密、MD5加密
  • RAG Embeddings 向量数据库
  • 使用Ollama部署自己的本地模型
  • 疯狂星期四文案网第48天运营日记
  • 12 SQL进阶-锁(8.20)
  • Python语法速成课程(二)
  • 科技赋能,宁夏农业绘就塞上新“丰”景
  • 进程的概念:进程调度算法