audioLDM模型代码阅读(三)——变分自编码器VAE
distributions.py
先给出完整代码:
import torch
import numpy as np


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(
            device=self.parameters.device
        )
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
这段代码主要实现了几种概率分布的类和相关工具函数,尤其专注于对角高斯分布(Diagonal Gaussian Distribution) 的实现,这类分布在变分自编码器(VAE)、扩散模型等生成式模型中广泛用于潜在变量的概率建模。以下是详细解析:
1. 抽象基类 AbstractDistribution
class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()
- 作用:定义概率分布的抽象接口,规定所有具体分布类必须实现两个核心方法:
  - sample():从分布中采样一个样本。
  - mode():返回分布的众数(概率密度最高的点)。
- 设计目的:统一不同分布的调用接口,方便后续在模型中替换或使用不同分布(如高斯分布、狄拉克分布等)。
2. 狄拉克分布 DiracDistribution
class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value  # 分布的唯一确定性值

    def sample(self):
        return self.value  # 采样结果就是固定值

    def mode(self):
        return self.value  # 众数也是固定值
- 数学背景:狄拉克分布(Dirac delta distribution)是一种确定性分布,所有概率质量集中在单个点上(即 value)。
- 方法解析:
  - __init__:接收一个固定值 value,作为分布的唯一可能值。
  - sample() 和 mode():均返回 value,因为狄拉克分布的采样结果和众数都是这个固定值。
- 应用场景:用于确定性模型(非概率模型),或作为概率模型的特殊情况(当分布方差为0时)。
3. 对角高斯分布 DiagonalGaussianDistribution
这是代码的核心类,实现了各维度独立的多元高斯分布(即协方差矩阵为对角矩阵),适用于高维潜在变量(如VAE中的潜在空间)。
3.1 初始化方法 __init__
def __init__(self, parameters, deterministic=False):
    self.parameters = parameters  # 输入参数(包含均值和对数方差)
    # 将参数沿通道维度(dim=1)分为均值(mean)和对数方差(logvar)
    self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
    # 截断logvar到[-30, 20],避免数值溢出(防止std过大或过小)
    self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
    self.deterministic = deterministic  # 是否为确定性模式(方差为0)
    # 计算标准差(std)和方差(var):std = exp(0.5 * logvar),var = exp(logvar)
    self.std = torch.exp(0.5 * self.logvar)
    self.var = torch.exp(self.logvar)
    # 若为确定性模式,强制方差和标准差为0(退化为狄拉克分布)
    if self.deterministic:
        self.var = self.std = torch.zeros_like(self.mean).to(
            device=self.parameters.device
        )
- 核心参数:parameters 是一个张量,包含分布的均值和对数方差(沿 dim=1 拼接,例如形状为 (batch, 2*z_dim, ...),分块后得到 mean 和 logvar,各为 (batch, z_dim, ...))。
- 数值稳定性:logvar 被截断到 [-30, 20],因为:
  - 若 logvar > 20,std = exp(10) ≈ 2万,可能导致采样结果过大;
  - 若 logvar < -30,std = exp(-15) ≈ 3e-7,接近0,可能导致梯度消失。
- 确定性模式:当 deterministic=True 时,方差和标准差被置为0,分布退化为以 mean 为中心的狄拉克分布(用于推理时避免随机性)。
3.2 采样方法 sample()
def sample(self):
    # 重参数化采样:mean + std * 标准正态噪声(确保梯度可传)
    x = self.mean + self.std * torch.randn(self.mean.shape).to(
        device=self.parameters.device
    )
    return x
- 核心思想:使用重参数化技巧(Reparameterization Trick),将采样过程表示为 mean + std * ε(其中 ε ~ N(0,1)),使采样操作可导(避免因随机采样导致的梯度断裂),下面给出一个最小示例。
- 适用场景:训练时从潜在分布中采样,引入随机性以符合VAE的概率建模要求。
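下面用纯 PyTorch 写一个最小示意(不依赖本文件,mean、logvar 都是随手构造的示例张量),验证重参数化采样后梯度仍能回传到分布参数:

import torch

# 假设这是编码器输出的分布参数(示意数据,并非真实模型输出)
mean = torch.zeros(2, 4, requires_grad=True)
logvar = torch.zeros(2, 4, requires_grad=True)
std = torch.exp(0.5 * logvar)

# 重参数化:z = mean + std * ε,ε ~ N(0, 1)
eps = torch.randn_like(std)
z = mean + std * eps

# 对 z 构造一个标量损失并反向传播,梯度能够流回 mean 和 logvar
loss = (z ** 2).sum()
loss.backward()
print(mean.grad.shape, logvar.grad.shape)  # torch.Size([2, 4]) torch.Size([2, 4])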
3.3 KL散度计算 kl()
KL散度(Kullback-Leibler Divergence)用于衡量两个分布的差异,是VAE中约束潜在分布的核心损失。
def kl(self, other=None):
    if self.deterministic:
        return torch.Tensor([0.0])  # 确定性分布的KL散度为0
    else:
        if other is None:
            # 与标准正态分布N(0,1)的KL散度
            return 0.5 * torch.mean(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3],  # 对通道、时间、频率等维度求均值
            )
        else:
            # 与另一个高斯分布other的KL散度
            return 0.5 * torch.mean(
                torch.pow(self.mean - other.mean, 2) / other.var
                + self.var / other.var
                - 1.0
                - self.logvar
                + other.logvar,
                dim=[1, 2, 3],
            )
- 数学公式:两个高斯分布 N(μ₁, σ₁²) 和 N(μ₂, σ₂²) 的KL散度为:
  KL = 0.5 * [ (μ₁-μ₂)²/σ₂² + σ₁²/σ₂² - 1 - log(σ₁²/σ₂²) ]
  当 other=None 时,默认与标准正态分布 N(0,1) 比较,公式简化为:
  KL = 0.5 * (μ₁² + σ₁² - 1 - logσ₁²)
- 维度处理:dim=[1,2,3] 表示对潜在变量的通道、时间、频率等维度取均值(适用于图像、音频等多维特征),因此每个样本得到一个标量,见下方的数值验证示例。
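下面给出一个数值验证示例(假设已安装 audioldm 包、可从 audioldm.variational_autoencoder.distributions 导入该类;若未安装,也可把上文的类定义复制到本地运行),检查 kl() 与上面的闭式公式一致:

import torch
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution

torch.manual_seed(0)
# 构造 (batch, 2*z_dim, T, F) 的参数张量:前一半通道是均值,后一半是对数方差
params = torch.randn(2, 8, 4, 4)
dist = DiagonalGaussianDistribution(params)

# 类方法:与 N(0, 1) 的KL,在 dim=[1, 2, 3] 上取均值
kl_from_class = dist.kl()

# 手工按闭式公式逐元素计算后,在相同维度取均值
manual = 0.5 * (dist.mean ** 2 + dist.var - 1.0 - dist.logvar)
kl_manual = manual.mean(dim=[1, 2, 3])

print(torch.allclose(kl_from_class, kl_manual))  # True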
3.4 负对数似然 nll()
def nll(self, sample, dims=[1, 2, 3]):
    if self.deterministic:
        return torch.Tensor([0.0])  # 确定性分布的负对数似然为0
    logtwopi = np.log(2.0 * np.pi)  # log(2π)是常数
    # 高斯分布的负对数似然公式
    return 0.5 * torch.sum(
        logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
        dim=dims,
    )
- 数学背景:若样本 x 服从高斯分布 N(μ, σ²),则其对数似然为:
  log p(x) = -0.5 * [log(2π) + logσ² + (x-μ)²/σ²]
  负对数似然(NLL)为上述值的相反数,代码在 dims 指定的维度上求和。
- 作用:衡量样本 sample 与当前分布的匹配程度,常用于生成模型的损失计算,下面用 torch.distributions 做一个对照。
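一个简单的对照示例(同样假设 audioldm 可导入,仅作示意):nll 在指定维度求和的结果,应等于逐元素 log_prob 取负后再求和:

import torch
from torch.distributions import Normal
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution

torch.manual_seed(0)
params = torch.randn(2, 8, 4, 4)            # (batch, 2*z_dim, T, F)
dist = DiagonalGaussianDistribution(params)
x = torch.randn(2, 4, 4, 4)                 # 与 mean 同形状的样本

nll_from_class = dist.nll(x)                # 在 dim=[1, 2, 3] 上求和

# torch.distributions.Normal 以标准差为参数,log_prob 给出逐元素对数似然
ref = -Normal(dist.mean, dist.std).log_prob(x).sum(dim=[1, 2, 3])
print(torch.allclose(nll_from_class, ref, atol=1e-4))  # True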
3.5 众数 mode()
def mode(self):
    return self.mean  # 高斯分布的众数等于均值
- 高斯分布是单峰分布,其概率密度最高的点(众数)就是均值 mean。
- 应用场景:推理时生成确定性结果(如不采样,直接用均值作为潜在变量)。
4. 高斯KL散度工具函数 normal_kl
def normal_kl(mean1, logvar1, mean2, logvar2):
    """计算两个高斯分布N(mean1, exp(logvar1))和N(mean2, exp(logvar2))的KL散度"""
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "至少一个参数必须是Tensor"

    # 确保logvar1和logvar2是Tensor(支持标量输入)
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    # KL散度公式(逐元素计算)
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
- 功能:通用的高斯KL散度计算函数,支持输入为张量或标量(通过广播机制适配不同形状)。
- 与 DiagonalGaussianDistribution.kl 的区别:后者是类方法,针对当前分布与另一个分布(或标准正态)计算KL散度并在指定维度求均值;前者是独立函数,逐元素计算KL散度,不做均值或求和,更灵活(如用于需要保留空间维度的场景),见下方对照示例。
- 来源:注释提到来自 openai/guided-diffusion,常用于扩散模型中的损失计算。
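可以用 PyTorch 内置实现做一个逐元素对照(示意示例,假设 normal_kl 可从 audioldm.variational_autoencoder.distributions 导入):

import torch
from torch.distributions import Normal, kl_divergence
from audioldm.variational_autoencoder.distributions import normal_kl

torch.manual_seed(0)
mean1, logvar1 = torch.randn(3, 5), torch.randn(3, 5)
mean2, logvar2 = torch.randn(3, 5), torch.randn(3, 5)

# 本文件中的工具函数:逐元素KL,不做任何求和/均值
kl_a = normal_kl(mean1, logvar1, mean2, logvar2)

# PyTorch 内置实现:同样是逐元素KL(Normal 以标准差为参数)
kl_b = kl_divergence(
    Normal(mean1, torch.exp(0.5 * logvar1)),
    Normal(mean2, torch.exp(0.5 * logvar2)),
)
print(torch.allclose(kl_a, kl_b, atol=1e-5))  # True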
总结
- 这段代码围绕概率分布的核心操作(采样、众数、KL散度、负对数似然)构建了工具类,尤其聚焦于对角高斯分布——这是生成式模型(如VAE、扩散模型)中潜在变量建模的核心组件。
- 关键设计考虑:
  - 接口统一:通过 AbstractDistribution 确保不同分布的调用一致性;
  - 数值稳定性:对 logvar 截断,避免标准差过大或过小;
  - 灵活性:支持确定性模式(推理用)和概率模式(训练用),并提供通用的KL散度计算函数。
这些工具为后续构建生成模型(如前文的 AutoencoderKL)提供了概率建模的基础,特别是在计算损失(如KL损失、重建损失)和潜在变量采样时不可或缺。
modules.py
先给出完整的代码:
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrangefrom audioldm.utils import instantiate_from_config
from audioldm.latent_diffusion.attention import LinearAttentiondef get_timestep_embedding(timesteps, embedding_dim):"""This matches the implementation in Denoising Diffusion Probabilistic Models:From Fairseq.Build sinusoidal embeddings.This matches the implementation in tensor2tensor, but differs slightlyfrom the description in Section 3.5 of "Attention Is All You Need"."""assert len(timesteps.shape) == 1half_dim = embedding_dim // 2emb = math.log(10000) / (half_dim - 1)emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)emb = emb.to(device=timesteps.device)emb = timesteps.float()[:, None] * emb[None, :]emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)if embedding_dim % 2 == 1: # zero pademb = torch.nn.functional.pad(emb, (0, 1, 0, 0))return embdef nonlinearity(x):# swishreturn x * torch.sigmoid(x)def Normalize(in_channels, num_groups=32):return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)class Upsample(nn.Module):def __init__(self, in_channels, with_conv):super().__init__()self.with_conv = with_convif self.with_conv:self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)def forward(self, x):x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")if self.with_conv:x = self.conv(x)return xclass UpsampleTimeStride4(nn.Module):def __init__(self, in_channels, with_conv):super().__init__()self.with_conv = with_convif self.with_conv:self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=5, stride=1, padding=2)def forward(self, x):x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")if self.with_conv:x = self.conv(x)return xclass Downsample(nn.Module):def __init__(self, in_channels, with_conv):super().__init__()self.with_conv = with_convif self.with_conv:# Do time downsampling here# no asymmetric padding in torch conv, must do it ourselvesself.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)def forward(self, x):if self.with_conv:pad = (0, 1, 0, 1)x = torch.nn.functional.pad(x, pad, mode="constant", value=0)x = self.conv(x)else:x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)return xclass DownsampleTimeStride4(nn.Module):def __init__(self, in_channels, with_conv):super().__init__()self.with_conv = with_convif self.with_conv:# Do time downsampling here# no asymmetric padding in torch conv, must do it ourselvesself.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1)def forward(self, x):if self.with_conv:pad = (0, 1, 0, 1)x = torch.nn.functional.pad(x, pad, mode="constant", value=0)x = self.conv(x)else:x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))return xclass ResnetBlock(nn.Module):def __init__(self,*,in_channels,out_channels=None,conv_shortcut=False,dropout,temb_channels=512,):super().__init__()self.in_channels = in_channelsout_channels = in_channels if out_channels is None else out_channelsself.out_channels = out_channelsself.use_conv_shortcut = conv_shortcutself.norm1 = Normalize(in_channels)self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)if temb_channels > 0:self.temb_proj = torch.nn.Linear(temb_channels, out_channels)self.norm2 = Normalize(out_channels)self.dropout = torch.nn.Dropout(dropout)self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)if self.in_channels != self.out_channels:if self.use_conv_shortcut:self.conv_shortcut = 
torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)else:self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)def forward(self, x, temb):h = xh = self.norm1(h)h = nonlinearity(h)h = self.conv1(h)if temb is not None:h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]h = self.norm2(h)h = nonlinearity(h)h = self.dropout(h)h = self.conv2(h)if self.in_channels != self.out_channels:if self.use_conv_shortcut:x = self.conv_shortcut(x)else:x = self.nin_shortcut(x)return x + hclass LinAttnBlock(LinearAttention):"""to match AttnBlock usage"""def __init__(self, in_channels):super().__init__(dim=in_channels, heads=1, dim_head=in_channels)class AttnBlock(nn.Module):def __init__(self, in_channels):super().__init__()self.in_channels = in_channelsself.norm = Normalize(in_channels)self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)def forward(self, x):h_ = xh_ = self.norm(h_)q = self.q(h_)k = self.k(h_)v = self.v(h_)# compute attentionb, c, h, w = q.shapeq = q.reshape(b, c, h * w).contiguous()q = q.permute(0, 2, 1).contiguous() # b,hw,ck = k.reshape(b, c, h * w).contiguous() # b,c,hww_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]w_ = w_ * (int(c) ** (-0.5))w_ = torch.nn.functional.softmax(w_, dim=2)# attend to valuesv = v.reshape(b, c, h * w).contiguous()w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]h_ = h_.reshape(b, c, h, w).contiguous()h_ = self.proj_out(h_)return x + h_def make_attn(in_channels, attn_type="vanilla"):assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"# print(f"making attention of type '{attn_type}' with {in_channels} in_channels")if attn_type == "vanilla":return AttnBlock(in_channels)elif attn_type == "none":return nn.Identity(in_channels)else:return LinAttnBlock(in_channels)class Model(nn.Module):def __init__(self,*,ch,out_ch,ch_mult=(1, 2, 4, 8),num_res_blocks,attn_resolutions,dropout=0.0,resamp_with_conv=True,in_channels,resolution,use_timestep=True,use_linear_attn=False,attn_type="vanilla",):super().__init__()if use_linear_attn:attn_type = "linear"self.ch = chself.temb_ch = self.ch * 4self.num_resolutions = len(ch_mult)self.num_res_blocks = num_res_blocksself.resolution = resolutionself.in_channels = in_channelsself.use_timestep = use_timestepif self.use_timestep:# timestep embeddingself.temb = nn.Module()self.temb.dense = nn.ModuleList([torch.nn.Linear(self.ch, self.temb_ch),torch.nn.Linear(self.temb_ch, self.temb_ch),])# downsamplingself.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)curr_res = resolutionin_ch_mult = (1,) + tuple(ch_mult)self.down = nn.ModuleList()for i_level in range(self.num_resolutions):block = nn.ModuleList()attn = nn.ModuleList()block_in = ch * in_ch_mult[i_level]block_out = ch * ch_mult[i_level]for i_block in range(self.num_res_blocks):block.append(ResnetBlock(in_channels=block_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout,))block_in = block_outif curr_res in attn_resolutions:attn.append(make_attn(block_in, attn_type=attn_type))down = 
nn.Module()down.block = blockdown.attn = attnif i_level != self.num_resolutions - 1:down.downsample = Downsample(block_in, resamp_with_conv)curr_res = curr_res // 2self.down.append(down)# middleself.mid = nn.Module()self.mid.block_1 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)self.mid.block_2 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)# upsamplingself.up = nn.ModuleList()for i_level in reversed(range(self.num_resolutions)):block = nn.ModuleList()attn = nn.ModuleList()block_out = ch * ch_mult[i_level]skip_in = ch * ch_mult[i_level]for i_block in range(self.num_res_blocks + 1):if i_block == self.num_res_blocks:skip_in = ch * in_ch_mult[i_level]block.append(ResnetBlock(in_channels=block_in + skip_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout,))block_in = block_outif curr_res in attn_resolutions:attn.append(make_attn(block_in, attn_type=attn_type))up = nn.Module()up.block = blockup.attn = attnif i_level != 0:up.upsample = Upsample(block_in, resamp_with_conv)curr_res = curr_res * 2self.up.insert(0, up) # prepend to get consistent order# endself.norm_out = Normalize(block_in)self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)def forward(self, x, t=None, context=None):# assert x.shape[2] == x.shape[3] == self.resolutionif context is not None:# assume aligned context, cat along channel axisx = torch.cat((x, context), dim=1)if self.use_timestep:# timestep embeddingassert t is not Nonetemb = get_timestep_embedding(t, self.ch)temb = self.temb.dense[0](temb)temb = nonlinearity(temb)temb = self.temb.dense[1](temb)else:temb = None# downsamplinghs = [self.conv_in(x)]for i_level in range(self.num_resolutions):for i_block in range(self.num_res_blocks):h = self.down[i_level].block[i_block](hs[-1], temb)if len(self.down[i_level].attn) > 0:h = self.down[i_level].attn[i_block](h)hs.append(h)if i_level != self.num_resolutions - 1:hs.append(self.down[i_level].downsample(hs[-1]))# middleh = hs[-1]h = self.mid.block_1(h, temb)h = self.mid.attn_1(h)h = self.mid.block_2(h, temb)# upsamplingfor i_level in reversed(range(self.num_resolutions)):for i_block in range(self.num_res_blocks + 1):h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)if len(self.up[i_level].attn) > 0:h = self.up[i_level].attn[i_block](h)if i_level != 0:h = self.up[i_level].upsample(h)# endh = self.norm_out(h)h = nonlinearity(h)h = self.conv_out(h)return hdef get_last_layer(self):return self.conv_out.weightclass Encoder(nn.Module):def __init__(self,*,ch,out_ch,ch_mult=(1, 2, 4, 8),num_res_blocks,attn_resolutions,dropout=0.0,resamp_with_conv=True,in_channels,resolution,z_channels,double_z=True,use_linear_attn=False,attn_type="vanilla",downsample_time_stride4_levels=[],**ignore_kwargs,):super().__init__()if use_linear_attn:attn_type = "linear"self.ch = chself.temb_ch = 0self.num_resolutions = len(ch_mult)self.num_res_blocks = num_res_blocksself.resolution = resolutionself.in_channels = in_channelsself.downsample_time_stride4_levels = downsample_time_stride4_levelsif len(self.downsample_time_stride4_levels) > 0:assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ("The level to perform downsample 4 operation need to be smaller than the total resolution number %s"% str(self.num_resolutions))# downsamplingself.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, 
padding=1)curr_res = resolutionin_ch_mult = (1,) + tuple(ch_mult)self.in_ch_mult = in_ch_multself.down = nn.ModuleList()for i_level in range(self.num_resolutions):block = nn.ModuleList()attn = nn.ModuleList()block_in = ch * in_ch_mult[i_level]block_out = ch * ch_mult[i_level]for i_block in range(self.num_res_blocks):block.append(ResnetBlock(in_channels=block_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout,))block_in = block_outif curr_res in attn_resolutions:attn.append(make_attn(block_in, attn_type=attn_type))down = nn.Module()down.block = blockdown.attn = attnif i_level != self.num_resolutions - 1:if i_level in self.downsample_time_stride4_levels:down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)else:down.downsample = Downsample(block_in, resamp_with_conv)curr_res = curr_res // 2self.down.append(down)# middleself.mid = nn.Module()self.mid.block_1 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)self.mid.block_2 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)# endself.norm_out = Normalize(block_in)self.conv_out = torch.nn.Conv2d(block_in,2 * z_channels if double_z else z_channels,kernel_size=3,stride=1,padding=1,)def forward(self, x):# timestep embeddingtemb = None# downsamplinghs = [self.conv_in(x)]for i_level in range(self.num_resolutions):for i_block in range(self.num_res_blocks):h = self.down[i_level].block[i_block](hs[-1], temb)if len(self.down[i_level].attn) > 0:h = self.down[i_level].attn[i_block](h)hs.append(h)if i_level != self.num_resolutions - 1:hs.append(self.down[i_level].downsample(hs[-1]))# middleh = hs[-1]h = self.mid.block_1(h, temb)h = self.mid.attn_1(h)h = self.mid.block_2(h, temb)# endh = self.norm_out(h)h = nonlinearity(h)h = self.conv_out(h)return hclass Decoder(nn.Module):def __init__(self,*,ch,out_ch,ch_mult=(1, 2, 4, 8),num_res_blocks,attn_resolutions,dropout=0.0,resamp_with_conv=True,in_channels,resolution,z_channels,give_pre_end=False,tanh_out=False,use_linear_attn=False,downsample_time_stride4_levels=[],attn_type="vanilla",**ignorekwargs,):super().__init__()if use_linear_attn:attn_type = "linear"self.ch = chself.temb_ch = 0self.num_resolutions = len(ch_mult)self.num_res_blocks = num_res_blocksself.resolution = resolutionself.in_channels = in_channelsself.give_pre_end = give_pre_endself.tanh_out = tanh_outself.downsample_time_stride4_levels = downsample_time_stride4_levelsif len(self.downsample_time_stride4_levels) > 0:assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ("The level to perform downsample 4 operation need to be smaller than the total resolution number %s"% str(self.num_resolutions))# compute in_ch_mult, block_in and curr_res at lowest resin_ch_mult = (1,) + tuple(ch_mult)block_in = ch * ch_mult[self.num_resolutions - 1]curr_res = resolution // 2 ** (self.num_resolutions - 1)self.z_shape = (1, z_channels, curr_res, curr_res)# print("Working with z of shape {} = {} dimensions.".format(# self.z_shape, np.prod(self.z_shape)))# z to block_inself.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)# middleself.mid = nn.Module()self.mid.block_1 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)self.mid.block_2 = 
ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)# upsamplingself.up = nn.ModuleList()for i_level in reversed(range(self.num_resolutions)):block = nn.ModuleList()attn = nn.ModuleList()block_out = ch * ch_mult[i_level]for i_block in range(self.num_res_blocks + 1):block.append(ResnetBlock(in_channels=block_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout,))block_in = block_outif curr_res in attn_resolutions:attn.append(make_attn(block_in, attn_type=attn_type))up = nn.Module()up.block = blockup.attn = attnif i_level != 0:if i_level - 1 in self.downsample_time_stride4_levels:up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)else:up.upsample = Upsample(block_in, resamp_with_conv)curr_res = curr_res * 2self.up.insert(0, up) # prepend to get consistent order# endself.norm_out = Normalize(block_in)self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)def forward(self, z):# assert z.shape[1:] == self.z_shape[1:]self.last_z_shape = z.shape# timestep embeddingtemb = None# z to block_inh = self.conv_in(z)# middleh = self.mid.block_1(h, temb)h = self.mid.attn_1(h)h = self.mid.block_2(h, temb)# upsamplingfor i_level in reversed(range(self.num_resolutions)):for i_block in range(self.num_res_blocks + 1):h = self.up[i_level].block[i_block](h, temb)if len(self.up[i_level].attn) > 0:h = self.up[i_level].attn[i_block](h)if i_level != 0:h = self.up[i_level].upsample(h)# endif self.give_pre_end:return hh = self.norm_out(h)h = nonlinearity(h)h = self.conv_out(h)if self.tanh_out:h = torch.tanh(h)return hclass SimpleDecoder(nn.Module):def __init__(self, in_channels, out_channels, *args, **kwargs):super().__init__()self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),ResnetBlock(in_channels=in_channels,out_channels=2 * in_channels,temb_channels=0,dropout=0.0,),ResnetBlock(in_channels=2 * in_channels,out_channels=4 * in_channels,temb_channels=0,dropout=0.0,),ResnetBlock(in_channels=4 * in_channels,out_channels=2 * in_channels,temb_channels=0,dropout=0.0,),nn.Conv2d(2 * in_channels, in_channels, 1),Upsample(in_channels, with_conv=True),])# endself.norm_out = Normalize(in_channels)self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)def forward(self, x):for i, layer in enumerate(self.model):if i in [1, 2, 3]:x = layer(x, None)else:x = layer(x)h = self.norm_out(x)h = nonlinearity(h)x = self.conv_out(h)return xclass UpsampleDecoder(nn.Module):def __init__(self,in_channels,out_channels,ch,num_res_blocks,resolution,ch_mult=(2, 2),dropout=0.0,):super().__init__()# upsamplingself.temb_ch = 0self.num_resolutions = len(ch_mult)self.num_res_blocks = num_res_blocksblock_in = in_channelscurr_res = resolution // 2 ** (self.num_resolutions - 1)self.res_blocks = nn.ModuleList()self.upsample_blocks = nn.ModuleList()for i_level in range(self.num_resolutions):res_block = []block_out = ch * ch_mult[i_level]for i_block in range(self.num_res_blocks + 1):res_block.append(ResnetBlock(in_channels=block_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout,))block_in = block_outself.res_blocks.append(nn.ModuleList(res_block))if i_level != self.num_resolutions - 1:self.upsample_blocks.append(Upsample(block_in, True))curr_res = curr_res * 2# endself.norm_out = Normalize(block_in)self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)def forward(self, x):# upsamplingh = xfor k, i_level in enumerate(range(self.num_resolutions)):for 
i_block in range(self.num_res_blocks + 1):h = self.res_blocks[i_level][i_block](h, None)if i_level != self.num_resolutions - 1:h = self.upsample_blocks[k](h)h = self.norm_out(h)h = nonlinearity(h)h = self.conv_out(h)return hclass LatentRescaler(nn.Module):def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):super().__init__()# residual block, interpolate, residual blockself.factor = factorself.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,out_channels=mid_channels,temb_channels=0,dropout=0.0,)for _ in range(depth)])self.attn = AttnBlock(mid_channels)self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,out_channels=mid_channels,temb_channels=0,dropout=0.0,)for _ in range(depth)])self.conv_out = nn.Conv2d(mid_channels,out_channels,kernel_size=1,)def forward(self, x):x = self.conv_in(x)for block in self.res_block1:x = block(x, None)x = torch.nn.functional.interpolate(x,size=(int(round(x.shape[2] * self.factor)),int(round(x.shape[3] * self.factor)),),)x = self.attn(x).contiguous()for block in self.res_block2:x = block(x, None)x = self.conv_out(x)return xclass MergedRescaleEncoder(nn.Module):def __init__(self,in_channels,ch,resolution,out_ch,num_res_blocks,attn_resolutions,dropout=0.0,resamp_with_conv=True,ch_mult=(1, 2, 4, 8),rescale_factor=1.0,rescale_module_depth=1,):super().__init__()intermediate_chn = ch * ch_mult[-1]self.encoder = Encoder(in_channels=in_channels,num_res_blocks=num_res_blocks,ch=ch,ch_mult=ch_mult,z_channels=intermediate_chn,double_z=False,resolution=resolution,attn_resolutions=attn_resolutions,dropout=dropout,resamp_with_conv=resamp_with_conv,out_ch=None,)self.rescaler = LatentRescaler(factor=rescale_factor,in_channels=intermediate_chn,mid_channels=intermediate_chn,out_channels=out_ch,depth=rescale_module_depth,)def forward(self, x):x = self.encoder(x)x = self.rescaler(x)return xclass MergedRescaleDecoder(nn.Module):def __init__(self,z_channels,out_ch,resolution,num_res_blocks,attn_resolutions,ch,ch_mult=(1, 2, 4, 8),dropout=0.0,resamp_with_conv=True,rescale_factor=1.0,rescale_module_depth=1,):super().__init__()tmp_chn = z_channels * ch_mult[-1]self.decoder = Decoder(out_ch=out_ch,z_channels=tmp_chn,attn_resolutions=attn_resolutions,dropout=dropout,resamp_with_conv=resamp_with_conv,in_channels=None,num_res_blocks=num_res_blocks,ch_mult=ch_mult,resolution=resolution,ch=ch,)self.rescaler = LatentRescaler(factor=rescale_factor,in_channels=z_channels,mid_channels=tmp_chn,out_channels=tmp_chn,depth=rescale_module_depth,)def forward(self, x):x = self.rescaler(x)x = self.decoder(x)return xclass Upsampler(nn.Module):def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):super().__init__()assert out_size >= in_sizenum_blocks = int(np.log2(out_size // in_size)) + 1factor_up = 1.0 + (out_size % in_size)print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")self.rescaler = LatentRescaler(factor=factor_up,in_channels=in_channels,mid_channels=2 * in_channels,out_channels=in_channels,)self.decoder = Decoder(out_ch=out_channels,resolution=out_size,z_channels=in_channels,num_res_blocks=2,attn_resolutions=[],in_channels=None,ch=in_channels,ch_mult=[ch_mult for _ in range(num_blocks)],)def forward(self, x):x = self.rescaler(x)x = self.decoder(x)return xclass Resize(nn.Module):def __init__(self, in_channels=None, learned=False, 
mode="bilinear"):super().__init__()self.with_conv = learnedself.mode = modeif self.with_conv:print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")raise NotImplementedError()assert in_channels is not None# no asymmetric padding in torch conv, must do it ourselvesself.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)def forward(self, x, scale_factor=1.0):if scale_factor == 1.0:return xelse:x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)return xclass FirstStagePostProcessor(nn.Module):def __init__(self,ch_mult: list,in_channels,pretrained_model: nn.Module = None,reshape=False,n_channels=None,dropout=0.0,pretrained_config=None,):super().__init__()if pretrained_config is None:assert (pretrained_model is not None), 'Either "pretrained_model" or "pretrained_config" must not be None'self.pretrained_model = pretrained_modelelse:assert (pretrained_config is not None), 'Either "pretrained_model" or "pretrained_config" must not be None'self.instantiate_pretrained(pretrained_config)self.do_reshape = reshapeif n_channels is None:n_channels = self.pretrained_model.encoder.chself.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)self.proj = nn.Conv2d(in_channels, n_channels, kernel_size=3, stride=1, padding=1)blocks = []downs = []ch_in = n_channelsfor m in ch_mult:blocks.append(ResnetBlock(in_channels=ch_in, out_channels=m * n_channels, dropout=dropout))ch_in = m * n_channelsdowns.append(Downsample(ch_in, with_conv=False))self.model = nn.ModuleList(blocks)self.downsampler = nn.ModuleList(downs)def instantiate_pretrained(self, config):model = instantiate_from_config(config)self.pretrained_model = model.eval()# self.pretrained_model.train = Falsefor param in self.pretrained_model.parameters():param.requires_grad = False@torch.no_grad()def encode_with_pretrained(self, x):c = self.pretrained_model.encode(x)if isinstance(c, DiagonalGaussianDistribution):c = c.mode()return cdef forward(self, x):z_fs = self.encode_with_pretrained(x)z = self.proj_norm(z_fs)z = self.proj(z)z = nonlinearity(z)for submodel, downmodel in zip(self.model, self.downsampler):z = submodel(z, temb=None)z = downmodel(z)if self.do_reshape:z = rearrange(z, "b c h w -> b (h w) c")return z
这段代码是基于PyTorch实现的一套生成式模型组件,主要用于扩散模型(Diffusion Model)和变分自编码器(VAE)等生成式任务,尤其适配音频或图像等具有时空结构的数据(从 audioldm 相关引用可推测更偏向音频处理)。代码包含基础网络模块、扩散模型主网络、编码器、解码器及各类辅助模块,整体设计围绕卷积神经网络(CNN)、残差连接、注意力机制和扩散过程的时间步处理展开。
一、核心工具函数与基础模块
核心工具函数与基础模块是构建深层生成模型(如扩散模型、VAE)的“乐高积木”,负责处理特征变换、分辨率调整、梯度传播和长距离依赖捕捉等基础任务。这些模块的设计直接影响模型的性能、稳定性和适用性(如适配音频/图像等不同数据类型)。以下是详细解析:
1. 时间步嵌入(get_timestep_embedding)
扩散模型的核心组件,用于将离散的“时间步”(扩散过程中添加噪声的步骤)转换为连续向量,使模型能学习不同时间步的降噪规律。
def get_timestep_embedding(timesteps, embedding_dim):
    assert len(timesteps.shape) == 1  # 输入为1D张量(batch_size,)
    half_dim = embedding_dim // 2
    # 生成指数衰减的频率(类似Transformer位置编码)
    emb = math.log(10000) / (half_dim - 1)  # 频率缩放因子
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)  # 频率向量
    emb = emb.to(device=timesteps.device)  # 转移到与时间步相同的设备
    # 时间步与频率相乘,得到基础嵌入
    emb = timesteps.float()[:, None] * emb[None, :]  # 形状:(batch_size, half_dim)
    # 拼接正弦和余弦分量,增强表达能力
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)  # 形状:(batch_size, embedding_dim)
    # 若嵌入维度为奇数,补零以匹配维度
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))  # 补最后一维
    return emb
关键细节:
- 原理类似Transformer的位置编码,通过正弦/余弦函数将离散时间步映射到连续空间,确保“时间距离”与“嵌入向量距离”一致(如时间步1和2的嵌入比1和10更相似)。
- 输出维度为 embedding_dim,后续会通过全连接层进一步加工,融入模型的每个残差块,使模型能感知当前处于扩散过程的哪个阶段(见下方示例)。
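一个最小的形状示例(假设 get_timestep_embedding 可从 audioldm.variational_autoencoder.modules 导入;若不可用,把上面的函数定义复制到本地同样可以运行):

import torch
from audioldm.variational_autoencoder.modules import get_timestep_embedding

# 一个 batch 的扩散时间步(整数),形状为 (batch_size,)
t = torch.tensor([0, 1, 10, 999])

emb = get_timestep_embedding(t, embedding_dim=128)
print(emb.shape)       # torch.Size([4, 128])

emb_odd = get_timestep_embedding(t, embedding_dim=129)
print(emb_odd.shape)   # torch.Size([4, 129]):奇数维度时在最后一维补零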
2. 非线性激活与归一化
2.1 Swish激活函数(nonlinearity)
def nonlinearity(x):
    # swish激活:x * sigmoid(x)
    return x * torch.sigmoid(x)
- 优点:相比ReLU,Swish在负值区域不直接截断,而是平滑衰减,有助于缓解梯度消失,尤其适合深层网络。
- 无需额外参数,计算高效,在生成模型中广泛替代ReLU。
2.2 组归一化(Normalize)
def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )
- 原理:将输入特征的通道分成 num_groups 组,每组内独立计算均值和方差并归一化,最后通过可学习参数(affine=True)缩放和平移。
- 优势:不受批量大小影响(避免BatchNorm在小批量时的不稳定),适合生成模型(通常批量较小)。
- 音频/图像特征通常通道数较多(如256、512),num_groups=32 是经验值,平衡计算效率和归一化效果。
3. 上采样与下采样模块
用于调整特征图的空间/时间分辨率(如音频频谱的“时间步×频率 bins”维度),是多尺度特征学习的核心。
3.1 普通上采样(Upsample)
class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:  # 上采样后是否用卷积调整特征
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        # 近邻插值上采样,尺度放大2倍(适用于空间/时间维度)
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:  # 上采样后用3x3卷积消除锯齿,增强特征表达
            x = self.conv(x)
        return x
- 作用:将特征图分辨率翻倍(如28×28→56×56),with_conv=True 时通过卷积整合插值后的特征,避免高频噪声。
3.2 音频专用上采样(UpsampleTimeStride4)
class UpsampleTimeStride4(nn.Module):
    def forward(self, x):
        # 时间维度放大4倍,频率维度放大2倍(适配音频频谱特性)
        x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
        if self.with_conv:
            x = self.conv(x)  # 5x5卷积更适合捕捉音频的长时相关性
        return x
- 设计动机:音频频谱的时间维度(如毫秒级)比频率维度(如Hz)需要更大的缩放倍数(4倍 vs 2倍),以恢复时间连续性。
3.3 普通下采样(Downsample)
class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # 3x3卷积+步长2实现下采样(替代池化,保留更多特征)
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            # 不对称 padding(右、下各补1),确保下采样后尺寸为原图1/2
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            # 平均池化下采样(更简单,但可能丢失细节)
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
- 作用:将特征图分辨率减半(如56×56→28×28),with_conv=True 时用卷积下采样,比池化保留更多空间结构信息(下采样后的尺寸验证见下方示例)。
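可以用纯 PyTorch 验证"右、下各补1的不对称 padding + 3x3卷积(步长2)"确实把偶数尺寸正好减半(示意示例,与 Downsample 的卷积分支等价):

import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 56, 56)  # (batch, channels, H, W)
conv = torch.nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=0)

# F.pad 的顺序为(W左、W右、H上、H下):这里在宽度右侧和高度下侧各补1
x_pad = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
y = conv(x_pad)
print(y.shape)  # torch.Size([1, 64, 28, 28]):H、W 正好减半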
3.4 音频专用下采样(DownsampleTimeStride4)
class DownsampleTimeStride4(nn.Module):
    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            # 时间步长4,频率步长2,适配音频下采样需求
            x = self.conv(x)  # 5x5卷积捕捉更大范围特征
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
        return x
- 应用:与 UpsampleTimeStride4 对应,在编码阶段压缩音频频谱的时间维度(4倍)和频率维度(2倍),减少计算量。
4. 残差块(ResnetBlock)
生成式模型的基础构建块,通过残差连接解决深层网络的梯度消失问题,同时支持融入时间步嵌入(扩散模型专用)。
class ResnetBlock(nn.Module):def __init__(self,*,in_channels,out_channels=None,conv_shortcut=False,dropout,temb_channels=512, # 时间步嵌入维度):super().__init__()self.in_channels = in_channelsout_channels = in_channels if out_channels is None else out_channelsself.out_channels = out_channelsself.use_conv_shortcut = conv_shortcut # 残差连接是否用卷积调整通道# 第一层:归一化+激活+卷积self.norm1 = Normalize(in_channels)self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)# 时间步嵌入投影(扩散模型用)if temb_channels > 0:self.temb_proj = torch.nn.Linear(temb_channels, out_channels)# 第二层:归一化+激活+dropout+卷积self.norm2 = Normalize(out_channels)self.dropout = torch.nn.Dropout(dropout)self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)# 残差连接(输入输出通道不同时需要调整)if self.in_channels != self.out_channels:if self.use_conv_shortcut:self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)else:self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) # 1x1卷积调整通道def forward(self, x, temb):h = x # 残差分支# 第一层处理h = self.norm1(h) # 组归一化h = nonlinearity(h) # Swish激活h = self.conv1(h) # 3x3卷积升维/保持维度# 融入时间步嵌入(仅扩散模型,temb不为None时)if temb is not None:# 时间步嵌入先激活,再投影到当前通道数,最后广播到特征图维度h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]# 第二层处理h = self.norm2(h)h = nonlinearity(h)h = self.dropout(h) # 防止过拟合h = self.conv2(h)# 残差连接(输入x与残差分支h相加)if self.in_channels != self.out_channels:# 若通道不同,先调整x的通道x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)return x + h # 残差相加,保留原始特征并叠加新特征
核心作用:
- 残差连接(x + h)使梯度能直接从输出流回输入,解决深层网络梯度消失问题。
- 支持融入时间步嵌入(temb),让每个残差块都能感知当前扩散阶段,是扩散模型的关键设计。
- 两层卷积+归一化+激活的结构,能有效提取局部特征,同时通过 dropout 增强泛化能力。
5. 注意力模块
捕捉特征图中的长距离依赖(如音频的长时相关性、图像的全局结构),分为普通注意力和线性注意力。
5.1 普通自注意力(AttnBlock)
class AttnBlock(nn.Module):def __init__(self, in_channels):super().__init__()self.in_channels = in_channelsself.norm = Normalize(in_channels) # 归一化# 1x1卷积生成查询(q)、键(k)、值(v)self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)# 注意力输出投影self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)def forward(self, x):h_ = x # 注意力分支h_ = self.norm(h_) # 归一化# 生成q、k、v(形状:(batch, channels, height, width))q = self.q(h_)k = self.k(h_)v = self.v(h_)# 展平空间维度(height×width → 序列长度N)b, c, h, w = q.shapeq = q.reshape(b, c, h * w).contiguous() # (b, c, N)q = q.permute(0, 2, 1).contiguous() # (b, N, c):每个位置的查询向量k = k.reshape(b, c, h * w).contiguous() # (b, c, N):每个位置的键向量# 计算注意力权重:q与k的点积,缩放避免数值过大w_ = torch.bmm(q, k).contiguous() # (b, N, N):每个位置对其他位置的注意力w_ = w_ * (int(c) ** -0.5) # 缩放因子:通道数的平方根倒数w_ = torch.nn.functional.softmax(w_, dim=2) # 归一化权重# 注意力加权求和:用权重w_对v加权v = v.reshape(b, c, h * w).contiguous() # (b, c, N):每个位置的值向量w_ = w_.permute(0, 2, 1).contiguous() # (b, N, N):转置权重用于加权h_ = torch.bmm(v, w_).contiguous() # (b, c, N):加权后的值h_ = h_.reshape(b, c, h, w).contiguous() # 恢复空间维度# 投影输出+残差连接h_ = self.proj_out(h_)return x + h_ # 原始特征与注意力特征相加
关键细节:
- 本质是“空间自注意力”,将特征图的每个空间位置(如音频频谱的每个“时间-频率点”)视为序列元素,计算位置间的依赖关系。
- 复杂度为 O(N²)(N 为空间位置数),适合低分辨率特征图(如32×32),高分辨率时计算成本过高(注意力矩阵的规模见下方示例)。
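用纯 PyTorch 粗略感受注意力矩阵的规模(示意示例,只演示 q、k 展平后通过 bmm 得到 N×N 权重矩阵这一步):

import torch

b, c, h, w = 1, 64, 32, 32
q = torch.randn(b, c, h * w).permute(0, 2, 1)   # (b, N, c),N = h*w
k = torch.randn(b, c, h * w)                    # (b, c, N)

w_ = torch.bmm(q, k) * (c ** -0.5)              # (b, N, N) 的注意力权重
print(w_.shape)    # torch.Size([1, 1024, 1024])
print(w_.numel())  # 1048576:分辨率翻倍时,这个数会变为原来的16倍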
5.2 线性注意力(LinAttnBlock)
class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""

    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
- 继承自 LinearAttention(外部实现),通过核函数近似将注意力复杂度从 O(N²) 降至 O(N),适合高分辨率特征图。
- 牺牲部分表达能力换取效率,在对长距离依赖要求不极致的场景(如音频局部结构)中表现良好。
5.3 注意力选择器(make_attn)
def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
    if attn_type == "vanilla":
        return AttnBlock(in_channels)     # 普通注意力
    elif attn_type == "none":
        return nn.Identity(in_channels)   # 无注意力(仅残差连接)
    else:
        return LinAttnBlock(in_channels)  # 线性注意力
- 灵活选择注意力类型,根据任务需求(精度vs效率)切换,增强模型适应性。
总结
这些基础模块是生成模型的“基础设施”:
- 时间步嵌入让模型感知扩散阶段;
- Swish+GroupNorm确保特征变换的稳定性和表达力;
- 上/下采样实现多尺度特征学习,适配音频/图像的分辨率需求;
- 残差块解决深层网络训练难题,同时融入时间信息;
- 注意力模块捕捉长距离依赖,增强全局特征建模能力。
它们的组合构成了扩散模型、VAE等复杂生成模型的主体框架,其设计细节(如音频专用的采样策略)直接影响模型对特定数据类型的适配性。
二、核心网络结构
核心网络结构是这套代码的"主体框架",包括扩散模型主网络(Model)、编码器(Encoder)和解码器(Decoder),分别对应生成式模型的三个核心功能:降噪生成、数据压缩编码、潜在变量解码重建。这些结构基于前述基础模块(残差块、注意力、采样层等)搭建,针对时空数据(如音频频谱)的特性优化,下面详细解析:
1. 扩散模型主网络(Model)
扩散模型(如DDPM)的核心是"降噪网络",通过学习从含噪声数据中恢复原始数据的规律,实现生成。Model 类基于U-Net结构设计,支持时间步感知和注意力机制,是扩散过程的核心执行者。
1.1 初始化参数(__init__)
def __init__(
    self,
    *,
    ch,                      # 基础通道数
    out_ch,                  # 输出通道数(与输入数据通道一致)
    ch_mult=(1, 2, 4, 8),    # 各分辨率阶段的通道倍增因子
    num_res_blocks,          # 每个分辨率阶段的残差块数量
    attn_resolutions,        # 需要添加注意力的分辨率
    dropout=0.0,             # dropout概率
    resamp_with_conv=True,   # 采样时是否用卷积(而非池化)
    in_channels,             # 输入数据通道数
    resolution,              # 输入数据分辨率(如64x64)
    use_timestep=True,       # 是否使用时间步嵌入(扩散模型必须)
    use_linear_attn=False,   # 是否使用线性注意力
    attn_type="vanilla",     # 注意力类型(普通/线性/无)
):
- 核心参数解析:
  - ch_mult:控制网络深度,如 (1,2,4,8) 表示4个分辨率阶段,通道数依次为 ch×1 → ch×2 → ch×4 → ch×8(逐步加深)。
  - attn_resolutions:指定在哪些分辨率下添加注意力(如 (32, 16) 表示在32×32和16×16分辨率时使用注意力)。
  - use_timestep:扩散模型的标志,启用时间步嵌入以区分不同降噪阶段。
1.2 网络结构与前向传播(forward)
网络整体遵循"下采样→中间层→上采样"的U-Net流程,结合时间步嵌入和跳连(Skip Connection)融合多尺度特征。
def forward(self, x, t=None, context=None):# 步骤1:处理条件输入(可选,如文本引导生成)if context is not None:x = torch.cat((x, context), dim=1) # 沿通道维度拼接输入与条件# 步骤2:生成时间步嵌入(扩散模型核心)if self.use_timestep:assert t is not None # 必须传入时间步ttemb = get_timestep_embedding(t, self.ch) # 生成基础嵌入temb = self.temb.dense[0](temb) # 全连接层加工temb = nonlinearity(temb)temb = self.temb.dense[1](temb) # 最终时间步嵌入(维度:temb_ch=ch×4)else:temb = None # 非扩散模型不使用# 步骤3:下采样(Downsampling):逐步降低分辨率,提取高层特征hs = [self.conv_in(x)] # 初始卷积:将输入通道转为基础通道chfor i_level in range(self.num_resolutions): # 遍历每个分辨率阶段# 每个阶段包含多个残差块for i_block in range(self.num_res_blocks):h = self.down[i_level].block[i_block](hs[-1], temb) # 残差块处理(融入时间步)if len(self.down[i_level].attn) > 0: # 若当前分辨率需注意力h = self.down[i_level].attn[i_block](h) # 注意力增强特征hs.append(h) # 保存特征用于后续跳连# 阶段结束时下采样(最后一个阶段不采样)if i_level != self.num_resolutions - 1:hs.append(self.down[i_level].downsample(hs[-1])) # 分辨率减半# 步骤4:中间层(Middle):处理最深层特征h = hs[-1] # 取最深层特征h = self.mid.block_1(h, temb) # 残差块h = self.mid.attn_1(h) # 注意力h = self.mid.block_2(h, temb) # 残差块# 步骤5:上采样(Upsampling):逐步恢复分辨率,融合跳连特征for i_level in reversed(range(self.num_resolutions)): # 逆序遍历分辨率阶段# 每个阶段包含多个残差块,融合下采样时的跳连特征for i_block in range(self.num_res_blocks + 1):# 拼接当前特征与下采样阶段的对应特征(跳连)h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)if len(self.up[i_level].attn) > 0: # 注意力增强h = self.up[i_level].attn[i_block](h)# 阶段结束时上采样(第一个阶段不采样)if i_level != 0:h = self.up[i_level].upsample(h) # 分辨率翻倍# 步骤6:输出层:预测噪声或原始数据h = self.norm_out(h) # 归一化h = nonlinearity(h) # 激活h = self.conv_out(h) # 卷积输出(通道数=out_ch)return h
1.3 核心设计亮点
- U-Net结构:下采样(down)通过多层残差块和下采样层逐步压缩空间维度(如64→32→16→8),提取抽象特征;上采样(up)通过上采样层和残差块逐步恢复维度,同时融合下采样阶段的跳连特征(hs.pop()),保留细节信息。
- 时间步嵌入:时间步 t 通过 get_timestep_embedding 转为向量后,融入每个残差块(h += self.temb_proj(...)),使模型能学习不同噪声水平(不同 t)的降噪规律。
- 条件生成支持:通过 context 参数接收条件信号(如文本嵌入),与输入数据拼接后送入网络,实现条件生成(如"生成钢琴声")。
2. 编码器(Encoder)
编码器的作用是将原始数据(如音频频谱)压缩为潜在空间变量,通常用于VAE或作为扩散模型的"第一阶段"(将高维数据映射到低维潜在空间,降低生成难度)。其结构类似 Model 的下采样部分,但无时间步嵌入(非扩散过程)。
2.1 初始化参数(__init__)
在 Model 参数基础上新增:
def __init__(
    self,
    ...,
    z_channels,                         # 潜在变量通道数
    double_z=True,                      # 是否输出两倍通道(均值+对数方差,用于VAE的高斯分布)
    downsample_time_stride4_levels=[],  # 哪些阶段使用音频专用下采样(时间步长4)
    ...
):
double_z=True 是VAE的典型设计:输出通道数为 2×z_channels,前半为潜在变量的均值(mean),后半为对数方差(logvar),用于构建对角高斯分布(见前文 DiagonalGaussianDistribution),下面给出一个形状示例。
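一个最小的形状示例(假设 Encoder 与 DiagonalGaussianDistribution 可分别从 audioldm.variational_autoencoder.modules / distributions 导入;下面的配置数值只是为演示随手选的假设值,并非 AudioLDM 的真实配置):

import torch
from audioldm.variational_autoencoder.modules import Encoder
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution

encoder = Encoder(
    ch=32, out_ch=1, ch_mult=(1, 2, 4), num_res_blocks=1,
    attn_resolutions=[], in_channels=1, resolution=64,
    z_channels=8, double_z=True,
)

x = torch.randn(2, 1, 64, 64)   # (batch, 1, 时间, 频率) 的"频谱"示意输入
h = encoder(x)
print(h.shape)                  # torch.Size([2, 16, 16, 16]):通道为 2*z_channels,分辨率缩小4倍

posterior = DiagonalGaussianDistribution(h)
print(posterior.mean.shape, posterior.logvar.shape)  # 各为 (2, 8, 16, 16)
z = posterior.sample()
print(z.shape)                  # torch.Size([2, 8, 16, 16])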
2.2 前向传播(forward)
def forward(self, x):temb = None # 编码器无时间步(非扩散过程)# 下采样:与Model的下采样流程一致,但无时间步嵌入hs = [self.conv_in(x)] # 初始卷积for i_level in range(self.num_resolutions):for i_block in range(self.num_res_blocks):h = self.down[i_level].block[i_block](hs[-1], temb) # 残差块(无时间步)if len(self.down[i_level].attn) > 0:h = self.down[i_level].attn[i_block](h)hs.append(h)if i_level != self.num_resolutions - 1:# 支持音频专用下采样(时间步长4)hs.append(self.down[i_level].downsample(hs[-1]))# 中间层:处理最深层特征h = hs[-1]h = self.mid.block_1(h, temb)h = self.mid.attn_1(h)h = self.mid.block_2(h, temb)# 输出潜在变量参数h = self.norm_out(h)h = nonlinearity(h)h = self.conv_out(h) # 输出:2×z_channels(均值+对数方差)或z_channelsreturn h
2.3 核心作用
- 数据压缩:通过多层下采样将高分辨率输入(如64×64频谱)压缩为低分辨率潜在变量(如8×8),减少数据维度。
- 概率建模:输出潜在变量的分布参数(均值+对数方差),为VAE的“变分”特性提供基础(通过KL散度约束潜在分布接近标准正态分布)。
3. 解码器(Decoder)
解码器是编码器的逆过程,将潜在变量重建为原始数据(如从VAE的潜在变量重建音频频谱)。结构与 Model 的上采样部分类似,同样无时间步嵌入。
3.1 初始化参数(__init__)
在 Encoder 参数基础上新增:
def __init__(
    self,
    ...,
    give_pre_end=False,  # 是否返回最终卷积前的特征
    tanh_out=False,      # 是否用tanh约束输出范围(如[-1,1])
    ...
):
tanh_out=True 用于约束输出在 [-1,1] 范围内,适合输入数据已归一化的场景(如音频频谱通常归一化到该范围)。
3.2 前向传播(forward)
def forward(self, z):temb = None # 无时间步# 潜在变量映射到高层特征h = self.conv_in(z) # 将潜在变量通道转为解码器基础通道# 中间层:处理高层特征h = self.mid.block_1(h, temb)h = self.mid.attn_1(h)h = self.mid.block_2(h, temb)# 上采样:逐步恢复分辨率(与编码器下采样对称)for i_level in reversed(range(self.num_resolutions)):for i_block in range(self.num_res_blocks + 1):h = self.up[i_level].block[i_block](h, temb) # 无跳连(与Encoder对称,无需融合)if len(self.up[i_level].attn) > 0:h = self.up[i_level].attn[i_block](h)if i_level != 0:# 支持音频专用上采样(时间×4,频率×2)h = self.up[i_level].upsample(h)# 输出重建结果if self.give_pre_end:return h # 返回最终卷积前的特征(用于中间任务)h = self.norm_out(h)h = nonlinearity(h)h = self.conv_out(h) # 输出通道=out_ch(与原始数据一致)if self.tanh_out:h = torch.tanh(h) # 约束范围return h
3.3 核心作用
- 重建数据:从潜在变量恢复原始数据的分辨率和细节,与编码器形成“编码-解码”闭环(如VAE的重建损失基于解码器输出与输入的差异)。
- 对称性设计:与编码器的下采样阶段严格对称(分辨率变化、通道数变化一致),确保潜在变量能被准确解码。
核心网络的协同关系
这三个网络并非孤立存在,而是通过“组合”实现复杂生成任务:
- VAE模式:
Encoder
编码输入为潜在分布→采样潜在变量→Decoder
重建输入,通过重建损失+KL散度训练,学习数据的压缩表示。 - 扩散+VAE模式:先用
Encoder
将数据映射到低维潜在空间→在潜在空间用Model
(扩散模型)学习生成→再用Decoder
将生成的潜在变量解码为最终数据(降低扩散模型的计算成本)。
这种组合兼顾了VAE的压缩能力和扩散模型的生成质量,是现代生成式模型(如Stable Diffusion、AudioLDM)的主流架构。
总结
核心网络结构围绕“生成”与“重建”设计:
Model
作为扩散模型的降噪核心,通过U-Net结构和时间步嵌入学习从噪声中恢复数据的规律;Encoder
和Decoder
构成自编码器,实现数据与潜在空间的双向映射,支持压缩、重建和作为扩散模型的前置处理。
三者均基于残差块、注意力和采样层构建,兼顾深层特征提取、长距离依赖捕捉和多尺度信息融合,为音频/图像等时空数据的生成任务提供了强大的基础架构。
三、辅助模块
辅助模块是对核心网络的补充与扩展,主要解决特定场景需求,如分辨率调整、轻量级重建、预训练模型融合等,增强了整个框架的灵活性和适用性(尤其针对音频生成的特殊需求)。以下是详细解析:
1. 简化解码器(SimpleDecoder
)
轻量化解码器,用于快速重建或低复杂度场景,通过少量残差块和上采样实现基础的特征恢复。
1.1 结构与前向传播
class SimpleDecoder(nn.Module):def __init__(self, in_channels, out_channels, *args, **kwargs):super().__init__()# 模块列表:1x1卷积→残差块×3→1x1卷积→上采样self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), # 通道调整ResnetBlock(in_channels=in_channels, out_channels=2*in_channels, temb_channels=0, dropout=0.0),ResnetBlock(in_channels=2*in_channels, out_channels=4*in_channels, temb_channels=0, dropout=0.0),ResnetBlock(in_channels=4*in_channels, out_channels=2*in_channels, temb_channels=0, dropout=0.0),nn.Conv2d(2*in_channels, in_channels, 1), # 通道压缩Upsample(in_channels, with_conv=True) # 上采样(尺度×2)])# 输出层self.norm_out = Normalize(in_channels)self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)def forward(self, x):for i, layer in enumerate(self.model):if i in [1, 2, 3]: # 残差块需要传入temb(此处为None,非扩散过程)x = layer(x, None)else:x = layer(x)# 最终处理h = self.norm_out(x)h = nonlinearity(h)return self.conv_out(h)
1.2 核心作用
- 轻量重建:相比完整 Decoder,用更少的残差块(3个)实现基础的特征解码,适合资源受限场景或作为辅助解码器。
- 固定流程:结构固定(通道先扩张后压缩+上采样),无需复杂参数配置,快速适配简单重建任务(如音频频谱的粗略恢复)。
2. 上采样解码器(UpsampleDecoder
)
专注于分辨率提升的解码器,通过多阶段残差块和上采样,逐步将低分辨率特征恢复为高分辨率输出。
2.1 结构与前向传播
class UpsampleDecoder(nn.Module):def __init__(self,in_channels,out_channels,ch, # 基础通道数num_res_blocks, # 每个阶段的残差块数量resolution, # 目标分辨率ch_mult=(2, 2), # 通道倍增因子dropout=0.0,):super().__init__()self.temb_ch = 0 # 无时间步self.num_resolutions = len(ch_mult)self.num_res_blocks = num_res_blocksblock_in = in_channels # 输入通道curr_res = resolution // (2 ** (self.num_resolutions - 1)) # 初始分辨率(逐步上采样至目标)# 残差块组(每个阶段一组)和上采样层self.res_blocks = nn.ModuleList()self.upsample_blocks = nn.ModuleList()for i_level in range(self.num_resolutions):res_block = []block_out = ch * ch_mult[i_level] # 当前阶段输出通道for i_block in range(self.num_res_blocks + 1):res_block.append(ResnetBlock(in_channels=block_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout))block_in = block_outself.res_blocks.append(nn.ModuleList(res_block))# 除最后一个阶段外,添加上采样层if i_level != self.num_resolutions - 1:self.upsample_blocks.append(Upsample(block_in, True))curr_res *= 2 # 分辨率翻倍# 输出层self.norm_out = Normalize(block_in)self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)def forward(self, x):h = x# 逐阶段处理:残差块→上采样(最后阶段不上采样)for k, i_level in enumerate(range(self.num_resolutions)):for i_block in range(self.num_res_blocks + 1):h = self.res_blocks[i_level][i_block](h, None) # 残差块处理if i_level != self.num_resolutions - 1:h = self.upsample_blocks[k](h) # 上采样# 输出处理h = self.norm_out(h)h = nonlinearity(h)return self.conv_out(h)
2.2 核心作用
- 渐进式上采样:通过多阶段残差块和上采样,逐步将低分辨率特征(如8×8)恢复为目标分辨率(如32×32),避免单次大尺度上采样导致的细节丢失。
- 灵活配置:通过 ch_mult 和 num_res_blocks 控制网络深度和通道扩张,适配不同分辨率提升需求(如音频频谱的时间维度拉长)。
3. 潜在空间尺度调整器(LatentRescaler)
调整潜在变量的分辨率(缩放尺寸),同时通过残差块和注意力保持特征一致性,用于适配不同网络间的潜在空间尺度差异。
3.1 结构与前向传播
class LatentRescaler(nn.Module):def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):super().__init__()self.factor = factor # 缩放因子(如1.5倍)# 输入卷积:调整通道至中间维度self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)# 残差块组1:缩放前增强特征self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,out_channels=mid_channels,temb_channels=0,dropout=0.0,) for _ in range(depth)])self.attn = AttnBlock(mid_channels) # 注意力:捕捉长距离依赖# 残差块组2:缩放后修复特征self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,out_channels=mid_channels,temb_channels=0,dropout=0.0,) for _ in range(depth)])# 输出卷积:调整通道至目标维度self.conv_out = nn.Conv2d(mid_channels, out_channels, kernel_size=1)def forward(self, x):x = self.conv_in(x) # 通道调整# 缩放前特征增强for block in self.res_block1:x = block(x, None)# 尺度调整(插值缩放)x = torch.nn.functional.interpolate(x,size=(int(round(x.shape[2] * self.factor)), # 高度×缩放因子int(round(x.shape[3] * self.factor)) # 宽度×缩放因子),)x = self.attn(x).contiguous() # 注意力修复缩放后的特征一致性# 缩放后特征增强for block in self.res_block2:x = block(x, None)return self.conv_out(x) # 输出目标通道
3.2 核心作用
- 跨尺度适配:解决不同网络输出的潜在变量分辨率不匹配问题(如将VAE的16×16潜在变量缩放到扩散模型所需的24×24)。
- 特征保持:缩放前后通过残差块增强特征,中间用注意力修复缩放导致的空间一致性丢失,确保调整后的潜在变量仍保留关键信息。
4. 融合尺度调整的编码器/解码器
将基础编码器/解码器与LatentRescaler
结合,实现“编码+尺度调整”或“尺度调整+解码”的端到端流程,简化多阶段生成 pipeline。
4.1 MergedRescaleEncoder
(编码+缩放)
class MergedRescaleEncoder(nn.Module):def __init__(self,in_channels,ch,resolution,out_ch, # 最终输出通道num_res_blocks,attn_resolutions,dropout=0.0,resamp_with_conv=True,ch_mult=(1, 2, 4, 8),rescale_factor=1.0, # 缩放因子rescale_module_depth=1, # 残差块深度):super().__init__()intermediate_chn = ch * ch_mult[-1] # 编码器输出通道# 基础编码器self.encoder = Encoder(in_channels=in_channels,num_res_blocks=num_res_blocks,ch=ch,ch_mult=ch_mult,z_channels=intermediate_chn,double_z=False, # 不输出均值+方差,仅输出特征resolution=resolution,attn_resolutions=attn_resolutions,dropout=dropout,resamp_with_conv=resamp_with_conv,out_ch=None,)# 编码后缩放self.rescaler = LatentRescaler(factor=rescale_factor,in_channels=intermediate_chn,mid_channels=intermediate_chn,out_channels=out_ch,depth=rescale_module_depth,)def forward(self, x):x = self.encoder(x) # 先编码x = self.rescaler(x) # 再缩放return x
4.2 MergedRescaleDecoder
(缩放+解码)
class MergedRescaleDecoder(nn.Module):def __init__(self,z_channels, # 输入潜在变量通道out_ch, # 输出通道resolution,num_res_blocks,attn_resolutions,ch,ch_mult=(1, 2, 4, 8),dropout=0.0,resamp_with_conv=True,rescale_factor=1.0,rescale_module_depth=1,):super().__init__()tmp_chn = z_channels * ch_mult[-1] # 中间通道# 基础解码器self.decoder = Decoder(out_ch=out_ch,z_channels=tmp_chn,attn_resolutions=attn_resolutions,dropout=dropout,resamp_with_conv=resamp_with_conv,in_channels=None,num_res_blocks=num_res_blocks,ch_mult=ch_mult,resolution=resolution,ch=ch,)# 解码前缩放self.rescaler = LatentRescaler(factor=rescale_factor,in_channels=z_channels,mid_channels=tmp_chn,out_channels=tmp_chn,depth=rescale_module_depth,)def forward(self, x):x = self.rescaler(x) # 先缩放x = self.decoder(x) # 再解码return x
4.3 核心作用
- 流程简化:将“编码→缩放”或“缩放→解码”合并为单一模块,减少多阶段模型的调用复杂度(如VAE编码后直接适配扩散模型的潜在空间)。
- 参数联动:确保编码器/解码器与缩放器的通道、分辨率参数匹配,避免手动调整的错误。
5. 通用上采样器(Upsampler)
针对特定分辨率提升需求设计的专用上采样器,结合 LatentRescaler 和 Decoder,实现从低分辨率到高分辨率的端到端生成。
class Upsampler(nn.Module):def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):super().__init__()assert out_size >= in_size # 确保输出分辨率不小于输入# 计算上采样阶段数(基于2的对数)num_blocks = int(np.log2(out_size // in_size)) + 1# 初始缩放因子(处理非2的幂次倍数)factor_up = 1.0 + (out_size % in_size)# 先缩放再解码self.rescaler = LatentRescaler(factor=factor_up,in_channels=in_channels,mid_channels=2 * in_channels,out_channels=in_channels,)self.decoder = Decoder(out_ch=out_channels,resolution=out_size,z_channels=in_channels,num_res_blocks=2,attn_resolutions=[],in_channels=None,ch=in_channels,ch_mult=[ch_mult for _ in range(num_blocks)], # 每个阶段通道倍增)def forward(self, x):x = self.rescaler(x) # 初步缩放x = self.decoder(x) # 解码至目标分辨率return x
5.1 核心作用
- 目标导向:直接指定输入(in_size)和输出(out_size)分辨率,自动计算所需上采样阶段和因子,无需手动配置。
- 专用优化:针对分辨率提升场景(如音频从32kHz升至44.1kHz)优化,平衡速度与质量。
6. 分辨率调整器(Resize
)
轻量级分辨率调整工具,支持学习式或固定插值方式调整特征图尺度,用于简单的分辨率适配。
class Resize(nn.Module):def __init__(self, in_channels=None, learned=False, mode="bilinear"):super().__init__()self.with_conv = learned # 是否用学习式卷积调整self.mode = mode # 插值方式(双线性/近邻等)if self.with_conv:# 学习式下采样(未实现,预留接口)raise NotImplementedError()self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)def forward(self, x, scale_factor=1.0):if scale_factor == 1.0:return x # 无需调整else:# 固定插值调整(非学习式)x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)return x
6.1 核心作用
- 轻量适配:用于无需复杂特征修复的场景(如中间特征的临时分辨率调整),计算成本低。
- 灵活模式:支持双线性、近邻等多种插值方式,适配不同数据的平滑度需求(如音频频谱用近邻保留锐度)。
7. 预训练模型后处理器(FirstStagePostProcessor
)
将预训练模型(如VAE)的输出映射到目标网络(如扩散模型)的潜在空间,实现多模型协同工作(如“预训练VAE+扩散模型”的两阶段生成)。
7.1 结构与前向传播
class FirstStagePostProcessor(nn.Module):def __init__(self,ch_mult: list, # 通道倍增因子in_channels,pretrained_model: nn.Module = None, # 预训练模型(如VAE)reshape=False, # 是否重塑特征为序列格式n_channels=None, # 中间通道数dropout=0.0,pretrained_config=None, # 预训练模型配置(用于实例化)):super().__init__()# 加载预训练模型(冻结参数,不参与训练)if pretrained_config is not None:self.instantiate_pretrained(pretrained_config)else:self.pretrained_model = pretrained_modelself.do_reshape = reshape# 通道投影:将预训练模型输出映射到中间通道if n_channels is None:n_channels = self.pretrained_model.encoder.chself.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)self.proj = nn.Conv2d(in_channels, n_channels, kernel_size=3, stride=1, padding=1)# 特征处理:残差块+下采样,适配目标网络的潜在空间blocks = []downs = []ch_in = n_channelsfor m in ch_mult:blocks.append(ResnetBlock(in_channels=ch_in, out_channels=m * n_channels, dropout=dropout))ch_in = m * n_channelsdowns.append(Downsample(ch_in, with_conv=False))self.model = nn.ModuleList(blocks)self.downsampler = nn.ModuleList(downs)def instantiate_pretrained(self, config):# 从配置实例化预训练模型并冻结self.pretrained_model = instantiate_from_config(config).eval()for param in self.pretrained_model.parameters():param.requires_grad = False@torch.no_grad() # 预训练模型不计算梯度def encode_with_pretrained(self, x):# 用预训练模型编码(如VAE的编码器)c = self.pretrained_model.encode(x)if isinstance(c, DiagonalGaussianDistribution): # 若输出为分布,取众数c = c.mode()return cdef forward(self, x):# 步骤1:预训练模型编码z_fs = self.encode_with_pretrained(x)# 步骤2:通道投影与增强z = self.proj_norm(z_fs)z = self.proj(z)z = nonlinearity(z)# 步骤3:特征处理(残差块+下采样)for submodel, downmodel in zip(self.model, self.downsampler):z = submodel(z, temb=None)z = downmodel(z)# 可选:重塑为序列格式(如用于Transformer)if self.do_reshape:z = rearrange(z, "b c h w -> b (h w) c")return z
7.2 核心作用
- 模型桥接:解决预训练模型与下游模型的潜在空间不兼容问题(如将VAE的潜在变量转换为扩散模型可处理的格式)。
- 特征适配:通过投影、残差块和下采样,调整预训练模型输出的通道和分辨率,使其符合下游模型的输入要求。
- 冻结预训练:预训练模型参数固定(requires_grad=False),仅训练适配层,避免破坏预训练知识。
总结
辅助模块通过以下方式增强了框架的实用性:
- 灵活性:提供轻量级解码器(SimpleDecoder)、通用上采样器(Upsampler)等,适配不同复杂度需求。
- 兼容性:通过 LatentRescaler、MergedRescaleEncoder/Decoder 解决不同网络间的尺度/通道不匹配问题。
- 扩展性:FirstStagePostProcessor 支持融合预训练模型,实现"预训练+微调"的两阶段生成,降低训练成本。
这些模块与核心网络(Model、Encoder、Decoder)协同,形成了一套完整的生成式模型工具链,尤其适合音频等时空数据的生成、重建和尺度调整任务。
总结
这段代码构建了一套完整的生成式模型工具链,核心包括:
- 基础组件:残差块、注意力、上/下采样、时间步嵌入,适配时空数据(音频/图像)。
- 核心网络:扩散模型主网络(Model)、VAE风格的编码器(Encoder)和解码器(Decoder)。
- 辅助工具:尺度调整、预训练模型融合等模块,支持复杂生成流程。
整体设计聚焦于灵活性和音频处理特性(如时间-频率维度的差异化采样),可用于音频生成、降噪、重建等任务,是AudioLDM等音频生成模型的核心组件。
autoencoder.py
先给出完整代码:
import torch
from audioldm.latent_diffusion.ema import *
from audioldm.variational_autoencoder.modules import Encoder, Decoder
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistributionfrom audioldm.hifigan.utilities import get_vocoder, vocoder_inferclass AutoencoderKL(nn.Module):def __init__(self,ddconfig=None,lossconfig=None,image_key="fbank",embed_dim=None,time_shuffle=1,subband=1,ckpt_path=None,reload_from_ckpt=None,ignore_keys=[],colorize_nlabels=None,monitor=None,base_learning_rate=1e-5,):super().__init__()self.encoder = Encoder(**ddconfig)self.decoder = Decoder(**ddconfig)self.subband = int(subband)if self.subband > 1:print("Use subband decomposition %s" % self.subband)self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)self.vocoder = get_vocoder(None, "cpu")self.embed_dim = embed_dimif monitor is not None:self.monitor = monitorself.time_shuffle = time_shuffleself.reload_from_ckpt = reload_from_ckptself.reloaded = Falseself.mean, self.std = None, Nonedef encode(self, x):# x = self.time_shuffle_operation(x)x = self.freq_split_subband(x)h = self.encoder(x)moments = self.quant_conv(h)posterior = DiagonalGaussianDistribution(moments)return posteriordef decode(self, z):z = self.post_quant_conv(z)dec = self.decoder(z)dec = self.freq_merge_subband(dec)return decdef decode_to_waveform(self, dec):dec = dec.squeeze(1).permute(0, 2, 1)wav_reconstruction = vocoder_infer(dec, self.vocoder)return wav_reconstructiondef forward(self, input, sample_posterior=True):posterior = self.encode(input)if sample_posterior:z = posterior.sample()else:z = posterior.mode()if self.flag_first_run:print("Latent size: ", z.size())self.flag_first_run = Falsedec = self.decode(z)return dec, posteriordef freq_split_subband(self, fbank):if self.subband == 1 or self.image_key != "stft":return fbankbs, ch, tstep, fbins = fbank.size()assert fbank.size(-1) % self.subband == 0assert ch == 1return (fbank.squeeze(1).reshape(bs, tstep, self.subband, fbins // self.subband).permute(0, 2, 1, 3))def freq_merge_subband(self, subband_fbank):if self.subband == 1 or self.image_key != "stft":return subband_fbankassert subband_fbank.size(1) == self.subband # Channel dimensionbs, sub_ch, tstep, fbins = subband_fbank.size()return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)
This code implements a PyTorch `AutoencoderKL` (an autoencoder regularized with a KL divergence), aimed mainly at audio signal processing (as the `audioldm`-related imports suggest). The model follows the variational autoencoder (VAE) idea: it contains an encoder, a decoder, and latent-variable projection layers, and it integrates speech-synthesis functionality. A detailed walkthrough follows:
1. Class and dependencies
- Base class: `AutoencoderKL` inherits from `torch.nn.Module`, the base class of all PyTorch network modules.
- Dependencies:
  - `Encoder` and `Decoder`, imported from `audioldm`: encode and decode audio features.
  - `DiagonalGaussianDistribution`: models the probability distribution of the latent variable (the core of the VAE).
  - The vocoder: converts the spectrogram features produced by the model into audible audio waveforms.
2. Initialization (`__init__`)
The `__init__` method of `AutoencoderKL` is the model's constructor. It initializes the core components, hyperparameters, and configuration, defining the model's basic structure and dependencies and laying the groundwork for the later `encode`, `decode`, and other operations. A part-by-part walkthrough:
Method signature and parameters
The signature is:
def __init__(
    self,
    ddconfig=None,
    lossconfig=None,
    image_key="fbank",
    embed_dim=None,
    time_shuffle=1,
    subband=1,
    ckpt_path=None,
    reload_from_ckpt=None,
    ignore_keys=[],
    colorize_nlabels=None,
    monitor=None,
    base_learning_rate=1e-5,
):
Parameter meanings (an illustrative configuration follows the list):
- `ddconfig`: configuration dict for the `Encoder` and `Decoder` (core parameters such as the number of layers and channel counts).
- `lossconfig`: configuration dict for the loss function (not used directly in this code; reserved for defining how the loss is computed).
- `image_key`: key name of the input feature ("fbank" for mel spectrograms, "stft" for short-time Fourier transform spectrograms), used to distinguish input feature types.
- `embed_dim`: dimensionality of the latent variable (the latent space).
- `time_shuffle`: parameter for shuffling along the time dimension (no actual operation is enabled in this code; reserved for data augmentation).
- `subband`: number of sub-bands the spectrogram is split into along the frequency axis (improves computational efficiency).
- `ckpt_path` / `reload_from_ckpt`: checkpoint paths (for loading pretrained weights).
- `ignore_keys`: parameter names to skip when loading a checkpoint (avoids errors from mismatched parameters).
- `colorize_nlabels`: reserved parameter (possibly for color-mapping multi-class features; rarely used in audio tasks).
- `monitor`: metric to track during training (e.g. the reconstruction loss), for logging or early stopping.
- `base_learning_rate`: base learning rate for the optimizer during training.
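For concreteness, an instantiation might look like the sketch below. The `ddconfig` values are illustrative placeholders (not the actual AudioLDM configuration), and running it requires the `audioldm` package plus its pretrained HiFi-GAN vocoder assets, since `__init__` calls `get_vocoder`:

# hypothetical configuration, for illustration only
ddconfig = {
    "ch": 128,               # base channel count of the encoder/decoder
    "out_ch": 1,             # channels of the reconstructed feature
    "ch_mult": (1, 2, 4),    # channel multipliers per resolution level
    "num_res_blocks": 2,
    "attn_resolutions": [],
    "dropout": 0.0,
    "resolution": 256,
    "in_channels": 1,        # single-channel spectrogram input
    "z_channels": 8,
    "double_z": True,        # encoder outputs 2*z_channels (mean and log-variance)
}

model = AutoencoderKL(ddconfig=ddconfig, embed_dim=8, image_key="fbank", subband=1)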
Core initialization logic
1. Call the parent constructor
super().__init__()
This is the standard step when defining a PyTorch module: it initializes the internal state of the parent class `nn.Module` (parameter management, device handling, and so on), so the model can rely on PyTorch's autograd, parameter saving, and related machinery.
2. Initialize the encoder and decoder
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
- `Encoder` and `Decoder` are the network modules imported from `audioldm.variational_autoencoder.modules`, built specifically for encoding and decoding audio features.
- `**ddconfig` unpacks the key-value pairs of the `ddconfig` dict as keyword arguments to the `Encoder` and `Decoder` constructors. For example, `ddconfig` may contain `z_channels` (latent feature channels), `channels` (layer widths), and similar parameters that define the network structure (number of convolution layers, output channels per layer, etc.).
- Role: the encoder compresses the input audio feature (e.g. a spectrogram) into higher-level features; the decoder restores the latent variable into a reconstructed audio feature.
3. Configure sub-band decomposition
self.subband = int(subband)
if self.subband > 1:
    print("Use subband decomposition %s" % self.subband)
- The `subband` parameter controls whether the spectrogram is split into sub-bands (effective only when the input feature is an STFT spectrogram). With `subband=4`, for example, the frequency axis is divided into 4 sub-bands, each covering part of the frequency range.
- Role: sub-band decomposition lowers the frequency dimension of each sub-band, reducing computation and accommodating how audio signals behave differently across frequency ranges.
4. Define the dimension-conversion convolutions for the latent variable
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
These two 1x1 convolution layers connect the encoder, the latent space, and the decoder:
- `quant_conv`: its input channel count is `2 * ddconfig["z_channels"]`, matching the encoder's output (the factor of 2 covers the mean and variance parameters), and its output channel count is `2 * embed_dim` (the combined dimensionality of the latent distribution's mean and variance). It converts the encoder's high-level features into the parameters of the latent probability distribution.
- `post_quant_conv`: its input channel count is `embed_dim` (the latent dimensionality) and its output channel count is `ddconfig["z_channels"]` (what the decoder expects). It adapts the latent variable to the decoder's input width.
- Because they are 1x1 convolutions, the spatial dimensions are left unchanged (only the channel count changes), so the conversion is cheap.
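A quick shape check of the two 1x1 convolutions, assuming `z_channels=8` and `embed_dim=8` (arbitrary example values, not taken from an AudioLDM config):

import torch
import torch.nn as nn

z_channels, embed_dim = 8, 8
quant_conv = nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
post_quant_conv = nn.Conv2d(embed_dim, z_channels, 1)

h = torch.randn(4, 2 * z_channels, 64, 16)   # encoder output: (batch, 2*z_channels, T, F)
moments = quant_conv(h)                      # -> (4, 2*embed_dim, 64, 16): mean and log-variance halves
z = torch.randn(4, embed_dim, 64, 16)        # a sampled latent variable
dec_in = post_quant_conv(z)                  # -> (4, z_channels, 64, 16): decoder input

print(moments.shape, dec_in.shape)           # the spatial dims (64, 16) stay unchanged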
5. Initialize the vocoder
self.vocoder = get_vocoder(None, "cpu")
- `get_vocoder` is a utility imported from `audioldm.hifigan.utilities` that returns a pretrained vocoder (HiFi-GAN). The vocoder turns the spectrogram features produced by the model (STFT or mel spectrograms) into audible audio waveforms.
- The device is initially set to `"cpu"`; it can be moved to a GPU later as needed.
6. Store core attributes
self.embed_dim = embed_dim                # dimensionality of the latent variable
if monitor is not None:
    self.monitor = monitor                # metric monitored during training
self.time_shuffle = time_shuffle          # time-shuffle parameter (reserved)
self.reload_from_ckpt = reload_from_ckpt  # checkpoint path (for later loading)
self.reloaded = False                     # whether a checkpoint has been loaded (initially not)
self.mean, self.std = None, None          # reserved for the data mean/std (e.g. for normalization)
These attributes record the model's configuration, state, or reserved functionality (the data mean/std for normalization, the checkpoint-loading state, and so on).
Summary
The core job of `__init__` is to set up the skeleton: by initializing the encoder, the decoder, the dimension-conversion convolutions, and the vocoder, it defines the basic structure of `AutoencoderKL`. It maps the configuration parameters (`ddconfig`, `embed_dim`, etc.) to the model's trainable components and state variables, providing the foundation for the later `encode`, `decode`, and other methods.
In particular, the method is tailored to audio processing:
- it supports sub-band decomposition to improve efficiency;
- it integrates a vocoder so audio waveforms can be produced directly;
- it models the latent variable as a probability distribution (implemented later in `encode`), reflecting the core idea of the variational autoencoder (VAE).
3. Core methods
The core methods of `AutoencoderKL` implement the model's main pipeline: encoding the input audio features (producing a latent distribution), sampling the latent variable, decoding (reconstructing the features), converting the result into an audible waveform, plus the auxiliary sub-band processing. A step-by-step walkthrough:
3.1 Encoding (`encode`): map the input features to a latent distribution
def encode(self, x):
    # x = self.time_shuffle_operation(x)   # reserved time-shuffle operation (disabled)
    x = self.freq_split_subband(x)          # optional sub-band split of the input feature
    h = self.encoder(x)                     # encoder outputs high-level features
    moments = self.quant_conv(h)            # convert to the latent distribution parameters (mean and variance)
    posterior = DiagonalGaussianDistribution(moments)  # build a diagonal Gaussian
    return posterior                        # return the posterior distribution
Details:
- Input `x`: usually an audio spectrogram feature (a mel spectrogram `fbank` or an STFT spectrogram), typically of shape `(batch_size, channels, time_steps, freq_bins)`.
- Sub-band split (`self.freq_split_subband(x)`): when `subband > 1` and the input is an STFT spectrogram (`image_key="stft"`), the spectrogram is split along the frequency axis into several sub-bands (see section 3.5), lowering the frequency dimension of each sub-band and reducing computation. Otherwise `x` is returned unchanged.
- Encoder pass (`self.encoder(x)`): the `Encoder` is a deep network (mostly convolutional) that compresses the (possibly split) feature `x` into a higher-level feature map `h` (its shape depends on `ddconfig`, e.g. `(batch_size, 2*z_channels, ...)`).
- Distribution parameters (`self.quant_conv(h)`): `quant_conv` is a 1x1 convolution that maps `h` (with `2*z_channels` channels, since the VAE needs both mean and variance parameters) to the latent distribution parameters `moments`, whose channel count becomes `2*embed_dim` (`embed_dim` is the latent dimensionality; the two halves hold the mean and the variance).
- Posterior construction (`DiagonalGaussianDistribution(moments)`): `moments` is split into a mean (`mean`) and a log-variance (`logvar`), forming a diagonal Gaussian (independent per dimension). This is the core of the VAE: the encoder does not output the latent variable directly but its probability distribution, which is what makes the model "variational".
- Return value: the posterior `posterior`, from which the latent variable `z` can later be sampled.
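A hedged usage sketch of `encode` (assuming an already constructed `model` of type `AutoencoderKL`; the input shape is illustrative):

import torch

fbank = torch.randn(2, 1, 1024, 64)   # (batch, channel, time_steps, freq_bins), placeholder sizes

posterior = model.encode(fbank)       # a DiagonalGaussianDistribution
z_train = posterior.sample()          # stochastic latent, as used during training
z_infer = posterior.mode()            # deterministic latent (= the mean), as used at inference

print(z_train.shape)                  # (2, embed_dim, T', F'), downsampled by the encoder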
3.2 Decoding (`decode`): reconstruct features from the latent variable
def decode(self, z):
    z = self.post_quant_conv(z)         # adapt the latent channels to the decoder's input
    dec = self.decoder(z)               # decoder outputs the reconstructed feature
    dec = self.freq_merge_subband(dec)  # merge the sub-bands (inverse of the split during encoding)
    return dec                          # return the reconstructed feature
Details:
- Input `z`: a latent variable sampled from the latent distribution, of shape `(batch_size, embed_dim, time_steps, freq_bins_sub)` (the frequency dimension after the sub-band split).
- Channel adaptation (`self.post_quant_conv(z)`): `post_quant_conv` is a 1x1 convolution that converts the latent variable's channel count from `embed_dim` to the `z_channels` the decoder expects (matching the encoder side), ensuring compatible dimensions.
- Decoder pass (`self.decoder(z)`): the `Decoder` is the inverse of the encoder; through its convolution layers (typically including upsampling) it restores the adapted latent variable `z` into the reconstructed sub-band feature `dec` (with shapes corresponding to the split features on the encoder side).
- Sub-band merge (`self.freq_merge_subband(dec)`): if the input was split into sub-bands during encoding, the sub-band features are merged back into the original spectrogram layout here (see section 3.5), so that the output shape matches the input `x`.
- Return value: the reconstructed audio feature `dec`, with the same shape as the input `x` (e.g. `(batch_size, channels, time_steps, freq_bins)`).
3.3 Waveform conversion (`decode_to_waveform`): turn the reconstructed feature into audible audio
def decode_to_waveform(self, dec):
    dec = dec.squeeze(1).permute(0, 2, 1)                   # rearrange the feature to the vocoder's input layout
    wav_reconstruction = vocoder_infer(dec, self.vocoder)   # vocoder synthesizes the waveform
    return wav_reconstruction
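As a quick aside before the detailed walkthrough, the dimension adjustment in the first line can be reproduced in isolation (shapes are illustrative; the vocoder call is commented out because it needs the pretrained HiFi-GAN weights):

import torch

dec = torch.randn(2, 1, 1024, 64)     # reconstructed feature: (batch, 1, time, freq)
x = dec.squeeze(1)                    # (2, 1024, 64): drop the channel dimension
x = x.permute(0, 2, 1)                # (2, 64, 1024): vocoder layout (batch, freq, time)
print(x.shape)
# wav = vocoder_infer(x, vocoder)     # would return roughly (batch, 1, sample_length)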
Details:
- Input `dec`: the decoded reconstruction (an STFT or mel spectrogram), of shape `(batch_size, 1, time_steps, freq_bins)` (usually single channel).
- Dimension adjustment:
  - `squeeze(1)`: drops the channel dimension (meaningless for a single channel), giving `(batch_size, time_steps, freq_bins)`.
  - `permute(0, 2, 1)`: swaps the time and frequency dimensions, giving `(batch_size, freq_bins, time_steps)`, which matches the vocoder's expected input layout (vocoders usually take `(batch, freq, time)`).
- Waveform synthesis (`vocoder_infer`): a vocoder such as HiFi-GAN is a model dedicated to turning spectrogram features into waveforms. `vocoder_infer` runs the pretrained `self.vocoder` on the rearranged `dec` and returns the waveform `wav_reconstruction`, of shape `(batch_size, 1, sample_length)` (mono audio with `sample_length` samples).
- Purpose: spectrogram features are "visual" representations that cannot be listened to directly; this method converts them into audible waveforms and is the final output step of an audio generation task.
3.4 Forward pass (`forward`): the full encode-sample-decode pipeline
def forward(self, input, sample_posterior=True):
    posterior = self.encode(input)        # encode into a posterior distribution
    # sample from the posterior (training) or take the mean (inference)
    if sample_posterior:
        z = posterior.sample()
    else:
        z = posterior.mode()
    # print the latent size once, on the first run only
    if self.flag_first_run:
        print("Latent size: ", z.size())
        self.flag_first_run = False
    dec = self.decode(z)                  # decode into the reconstructed feature
    return dec, posterior                 # return the reconstruction and the posterior
Details:
- Input parameters:
  - `input`: the raw audio feature (the same as `x` in `encode`).
  - `sample_posterior`: boolean controlling whether to sample from the posterior (`True`) or take its mean directly (`False`).
- Main flow:
  - Encode: call `encode` to obtain the posterior `posterior`.
  - Sample or take the mean:
    - During training (`sample_posterior=True`): sample `z` randomly from `posterior`, introducing stochasticity that keeps the latent space continuous (the "variational" constraint of a VAE).
    - During inference (`sample_posterior=False`): take the distribution's mean (`posterior.mode()`), avoiding the output instability that randomness would cause.
  - Decode: call `decode` to reconstruct the feature `dec` from `z`.
  - First-run print: `self.flag_first_run` (which starts as `True`) triggers a one-time print of the latent variable `z`'s size, handy when debugging the architecture.
- Return values:
  - `dec`: the reconstructed audio feature (used for the reconstruction loss, measuring the difference from the input `input`).
  - `posterior`: the posterior distribution (used for the KL divergence loss, which pulls the latent distribution toward the prior, usually a standard Gaussian).
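A minimal training-step sketch that consumes both return values. It assumes an MSE reconstruction loss and a small fixed KL weight, which is a simplification; the actual AudioLDM training objective (configured through `lossconfig`) may differ:

import torch
import torch.nn.functional as F

def training_step(model, fbank, optimizer, kl_weight=1e-6):
    dec, posterior = model(fbank, sample_posterior=True)  # encode -> sample -> decode
    rec_loss = F.mse_loss(dec, fbank)                     # reconstruction term
    kl_loss = posterior.kl().mean()                       # KL term from DiagonalGaussianDistribution.kl()
    loss = rec_loss + kl_weight * kl_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()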
3.5 Sub-band processing: splitting and merging the spectrogram
Sub-band processing is an optimization for STFT spectrograms: splitting the frequency axis reduces computation. It only takes effect when `subband > 1` and `image_key="stft"`.
3.5.1 Splitting (`freq_split_subband`)
def freq_split_subband(self, fbank):
    # if sub-band splitting does not apply, return the input unchanged
    if self.subband == 1 or self.image_key != "stft":
        return fbank
    # input shape: (batch_size, channels, time_steps, freq_bins)
    bs, ch, tstep, fbins = fbank.size()
    # checks: freq_bins must divide evenly into the sub-bands, and the input must be single channel
    assert fbank.size(-1) % self.subband == 0
    assert ch == 1
    # split the frequency axis into `subband` sub-bands
    return (
        fbank.squeeze(1)                                           # drop the channel dim: (bs, tstep, fbins)
        .reshape(bs, tstep, self.subband, fbins // self.subband)   # split the freq axis: (bs, tstep, subband, fbins/subband)
        .permute(0, 2, 1, 3)                                       # reorder: (bs, subband, tstep, fbins/subband), sub-bands become the channel dim
    )
- Role: splits the spectrogram's frequency axis (`fbins`) evenly into `subband` sub-bands, each with `fbins/subband` frequency bins, and places the sub-band index on the channel dimension so the sub-bands can be processed in parallel.
3.5.2 Merging (`freq_merge_subband`)
def freq_merge_subband(self, subband_fbank):
    # if sub-band merging does not apply, return the input unchanged
    if self.subband == 1 or self.image_key != "stft":
        return subband_fbank
    # check: the channel dimension must equal the number of sub-bands
    assert subband_fbank.size(1) == self.subband  # Channel dimension
    # input shape: (batch_size, subband, time_steps, fbins_sub)
    bs, sub_ch, tstep, fbins = subband_fbank.size()
    # merge the sub-bands back into the original frequency axis
    return (
        subband_fbank.permute(0, 2, 1, 3)   # reorder: (bs, tstep, subband, fbins_sub)
        .reshape(bs, tstep, -1)             # merge the sub-bands: (bs, tstep, subband*fbins_sub) = (bs, tstep, fbins)
        .unsqueeze(1)                       # restore the channel dim: (bs, 1, tstep, fbins)
    )
- Role: the inverse of `freq_split_subband`; it merges the sub-band features back into the original spectrogram layout, so the decoder output keeps the same shape as the input feature.
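The split and the merge are exact inverses, which can be verified with a small stand-alone re-implementation of the two reshapes (hypothetical shapes, `subband=4`):

import torch

subband = 4
fbank = torch.randn(2, 1, 1024, 512)      # (bs, 1, tstep, fbins), fbins divisible by subband
bs, ch, tstep, fbins = fbank.size()

# split: (bs, 1, tstep, fbins) -> (bs, subband, tstep, fbins // subband)
split = (
    fbank.squeeze(1)
    .reshape(bs, tstep, subband, fbins // subband)
    .permute(0, 2, 1, 3)
)

# merge: the inverse of the split
merged = split.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)

print(split.shape, merged.shape)          # (2, 4, 1024, 128) (2, 1, 1024, 512)
print(torch.equal(fbank, merged))         # True: the round trip is lossless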
How the core methods work together
Together these methods form the complete pipeline of `AutoencoderKL`:
input feature → `encode` (sub-band split → encoder → distribution modeling) → sample a latent variable → `decode` (decoder → sub-band merge) → reconstructed feature → `decode_to_waveform` (convert to an audio waveform)
`forward` is the external interface that chains the encode-sample-decode flow; sub-band processing is an optional efficiency optimization; and `decode_to_waveform` is the audio-specific "last mile" that turns abstract features into perceivable audio. The design follows the probabilistic modeling of a VAE while being specialized for the characteristics of audio signals.
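Putting the chain together, a hedged end-to-end reconstruction sketch (the input shape is illustrative, and running it needs the `audioldm` package plus the pretrained HiFi-GAN assets loaded by the constructor):

import torch

model.eval()                                     # an AutoencoderKL instance with loaded weights
fbank = torch.randn(1, 1, 1024, 64)              # placeholder mel-spectrogram batch

with torch.no_grad():
    posterior = model.encode(fbank)              # sub-band split -> encoder -> distribution
    z = posterior.mode()                         # deterministic latent for inference
    dec = model.decode(z)                        # decoder -> sub-band merge
    wav = model.decode_to_waveform(dec)          # spectrogram -> waveform via the vocoder

print(dec.shape, wav.shape)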
4. Core ideas of the model
This is a variant of the variational autoencoder (VAE). Its key characteristics:
- Probabilistic modeling of the latent variable: the encoder outputs a distribution over the latent variable (not a point estimate); sampling introduces stochasticity and improves generalization.
- KL divergence regularization: during training, the KL divergence between the posterior (`posterior`) and the prior (usually a standard Gaussian) constrains the shape of the latent space.
- Audio-specific optimizations:
  - sub-band processing lowers the spectrogram's frequency dimension and adapts to the frequency characteristics of audio;
  - the integrated vocoder supports converting features directly into waveforms, which is convenient for audio generation tasks.
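For reference, the training objective implied by these two terms is the standard VAE loss. Written generically (not quoted from the AudioLDM code, and with the caveat that `DiagonalGaussianDistribution.kl()` averages over the latent dimensions rather than summing):

$$
\mathcal{L} \;=\; \underbrace{\lVert x - \hat{x} \rVert^{2}}_{\text{reconstruction}}
\;+\; \beta\, \underbrace{D_{\mathrm{KL}}\!\left(q_\phi(z \mid x)\,\Vert\,\mathcal{N}(0, I)\right)}_{\text{KL regularization}},
\qquad
D_{\mathrm{KL}} \;=\; \tfrac{1}{2}\sum_{i}\left(\mu_i^{2} + \sigma_i^{2} - 1 - \log \sigma_i^{2}\right).
$$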
5. Application scenarios
The model can be used for audio generation, audio denoising, voice conversion, and similar tasks. For example:
- During training, it learns to encode input audio (speech, music) into a latent variable and then reconstruct the input, optimized by minimizing the reconstruction loss plus the KL loss.
- At inference time, a latent variable can be sampled from the prior and decoded into new audio (generation), or noisy audio can be encoded and then decoded to achieve denoising.
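A minimal sketch of the generation path mentioned above: sample the latent from the prior instead of encoding an input. The latent shape below is a placeholder and must match the trained model's latent size (printed on the first forward pass):

import torch

model.eval()
z = torch.randn(1, 8, 256, 16)              # hypothetical (batch, embed_dim, T', F')

with torch.no_grad():
    dec = model.decode(z)                   # decode a prior sample into a spectrogram
    wav = model.decode_to_waveform(dec)     # and then into a waveform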
In summary: this code implements a KL-regularized autoencoder for audio processing, combining the VAE's probabilistic modeling with audio-specific sub-band processing and waveform conversion. It is one of the core components of audio generation models such as AudioLDM.