
audioLDM模型代码阅读(三)——变分自编码器VAE

distributions.py

先给出完整代码:

import torch
import numpy as np


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(
            device=self.parameters.device
        )
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )

这段代码主要实现了几种概率分布的类和相关工具函数,尤其专注于对角高斯分布(Diagonal Gaussian Distribution) 的实现,这类分布在变分自编码器(VAE)、扩散模型等生成式模型中广泛用于潜在变量的概率建模。以下是详细解析:

1. 抽象基类 AbstractDistribution

class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()
  • 作用:定义概率分布的抽象接口,规定所有具体分布类必须实现两个核心方法:
    • sample():从分布中采样一个样本。
    • mode():返回分布的众数(概率密度最高的点)。
  • 设计目的:统一不同分布的调用接口,方便后续在模型中替换或使用不同分布(如高斯分布、狄拉克分布等)。

2. 狄拉克分布 DiracDistribution

class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value  # 分布的唯一确定性值

    def sample(self):
        return self.value  # 采样结果就是固定值

    def mode(self):
        return self.value  # 众数也是固定值
  • 数学背景:狄拉克分布(Dirac delta distribution)是一种确定性分布,所有概率质量集中在单个点上(即value)。
  • 方法解析
    • __init__:接收一个固定值value,作为分布的唯一可能值。
    • sample() 和 mode():均返回value,因为狄拉克分布的采样结果和众数都是这个固定值。
  • 应用场景:用于确定性模型(非概率模型),或作为概率模型的特殊情况(当分布方差为0时)。

3. 对角高斯分布 DiagonalGaussianDistribution

这是代码的核心类,实现了各维度独立的多元高斯分布(即协方差矩阵为对角矩阵),适用于高维潜在变量(如VAE中的潜在空间)。

3.1 初始化方法 __init__
def __init__(self, parameters, deterministic=False):
    self.parameters = parameters  # 输入参数(包含均值和对数方差)
    # 将参数沿通道维度(dim=1)分为均值(mean)和对数方差(logvar)
    self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
    # 截断logvar到[-30, 20],避免数值溢出(防止std过大或过小)
    self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
    self.deterministic = deterministic  # 是否为确定性模式(方差为0)
    # 计算标准差(std)和方差(var):std = exp(0.5 * logvar),var = exp(logvar)
    self.std = torch.exp(0.5 * self.logvar)
    self.var = torch.exp(self.logvar)
    # 若为确定性模式,强制方差和标准差为0(退化为狄拉克分布)
    if self.deterministic:
        self.var = self.std = torch.zeros_like(self.mean).to(
            device=self.parameters.device
        )
  • 核心参数parameters是一个张量,包含分布的均值和对数方差(沿dim=1拼接,例如形状为(batch, 2*z_dim, ...),分块后得到meanlogvar,各为(batch, z_dim, ...))。
  • 数值稳定性logvar被截断到[-30, 20],因为:
    • logvar > 20std = exp(10) ≈ 2万,可能导致采样结果过大;
    • logvar < -30std = exp(-15) ≈ 3e-7,接近0,可能导致梯度消失。
  • 确定性模式:当deterministic=True时,方差和标准差被置为0,分布退化为以mean为中心的狄拉克分布(用于推理时避免随机性)。
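
为了更直观地理解参数的拆分方式,下面给出一个最小的示意性例子(batch、通道数和时间/频率维度都是随意假设的,仅用于观察形状变化;导入路径按 AudioLDM 仓库中该模块的位置书写):

import torch
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution

# 假设编码器输出的参数张量形状为 (batch=2, 2*z_channels=16, T=8, F=8)
parameters = torch.randn(2, 16, 8, 8)
posterior = DiagonalGaussianDistribution(parameters)

print(posterior.mean.shape)    # torch.Size([2, 8, 8, 8]):前一半通道是均值
print(posterior.logvar.shape)  # torch.Size([2, 8, 8, 8]):后一半通道是对数方差(已截断到[-30, 20])
print(posterior.std.shape)     # torch.Size([2, 8, 8, 8]):std = exp(0.5 * logvar)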
3.2 采样方法 sample()
def sample(self):
    # 重参数化采样:mean + std * 标准正态噪声(确保梯度可传)
    x = self.mean + self.std * torch.randn(self.mean.shape).to(
        device=self.parameters.device
    )
    return x
  • 核心思想:使用重参数化技巧(Reparameterization Trick),将采样过程表示为mean + std * ε(其中ε ~ N(0,1)),使采样操作可导(避免因随机采样导致的梯度断裂)。
  • 适用场景:训练时从潜在分布中采样,引入随机性以符合VAE的概率建模要求。
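
重参数化的意义可以用一个小实验验证:对采样结果构造损失并反向传播,梯度能够传回产生 parameters 的张量(实际模型中即编码器的参数)。下面是一个示意性检查,形状同样是假设值:

import torch
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution

parameters = torch.randn(2, 16, 8, 8, requires_grad=True)  # 假装是编码器的输出
posterior = DiagonalGaussianDistribution(parameters)

z = posterior.sample()      # z = mean + std * eps,eps为标准正态噪声
loss = (z ** 2).mean()      # 任意依赖z的标量损失
loss.backward()

# 梯度经由mean和std传回parameters,说明采样操作对参数可导
print(parameters.grad is not None)  # True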
3.3 KL散度计算 kl()

KL散度(Kullback-Leibler Divergence)用于衡量两个分布的差异,是VAE中约束潜在分布的核心损失。

def kl(self, other=None):
    if self.deterministic:
        return torch.Tensor([0.0])  # 确定性分布的KL散度为0
    else:
        if other is None:
            # 与标准正态分布N(0,1)的KL散度
            return 0.5 * torch.mean(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3],  # 对通道、时间、频率等维度求均值
            )
        else:
            # 与另一个高斯分布other的KL散度
            return 0.5 * torch.mean(
                torch.pow(self.mean - other.mean, 2) / other.var
                + self.var / other.var
                - 1.0
                - self.logvar
                + other.logvar,
                dim=[1, 2, 3],
            )
  • 数学公式
    两个高斯分布N(μ₁, σ₁²)N(μ₂, σ₂²)的KL散度为:
    KL = 0.5 * [ (μ₁-μ₂)²/σ₂² + σ₁²/σ₂² - 1 - log(σ₁²/σ₂²) ]
    other=None时,默认与标准正态分布N(0,1)比较,公式简化为:
    KL = 0.5 * (μ₁² + σ₁² - 1 - logσ₁²)
  • 维度处理:dim=[1,2,3]表示在潜在变量的通道、时间、频率等维度上取均值,使每个样本得到一个标量KL值(适用于图像、音频等多维特征)。
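
在 VAE 训练中,通常把 kl() 的输出再对 batch 维取均值,作为一个标量损失项,以较小的权重加到重建损失上。下面是一个简化示意(形状与权重均为假设值,并非 AudioLDM 的实际训练配置):

import torch
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution

parameters = torch.randn(4, 16, 8, 8)   # 假设的编码器输出
posterior = DiagonalGaussianDistribution(parameters)

kl_loss = posterior.kl()                # 形状为 (4,):每个样本一个KL值
kl_loss = kl_loss.mean()                # 取均值得到标量
# 实际训练中大致为:total_loss = 重建损失 + kl_weight * kl_loss(kl_weight通常远小于1)
print(kl_loss.shape)                    # torch.Size([])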
3.4 负对数似然 nll()
def nll(self, sample, dims=[1, 2, 3]):
    if self.deterministic:
        return torch.Tensor([0.0])  # 确定性分布的负对数似然为0
    logtwopi = np.log(2.0 * np.pi)  # log(2π)是常数
    # 高斯分布的负对数似然公式
    return 0.5 * torch.sum(
        logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
        dim=dims,
    )
  • 数学背景:若样本x服从高斯分布N(μ, σ²),则其对数似然为:
    log p(x) = -0.5 * [log(2π) + logσ² + (x-μ)²/σ²]
    负对数似然(NLL)为上述值的相反数。
  • 作用:衡量样本sample与当前分布的匹配程度,常用于生成模型的损失计算。
3.5 众数 mode()
def mode(self):
    return self.mean  # 高斯分布的众数等于均值
  • 高斯分布是单峰分布,其概率密度最高的点(众数)就是均值mean
  • 应用场景:推理时生成确定性结果(如不采样,直接用均值作为潜在变量)。

4. 高斯KL散度工具函数 normal_kl

def normal_kl(mean1, logvar1, mean2, logvar2):
    """计算两个高斯分布N(mean1, exp(logvar1))和N(mean2, exp(logvar2))的KL散度"""
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "至少一个参数必须是Tensor"

    # 确保logvar1和logvar2是Tensor(支持标量输入)
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    # KL散度公式(逐元素计算)
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )

  • 功能:通用的高斯KL散度计算函数,支持输入为张量或标量(通过广播机制适配不同形状)。
  • 与DiagonalGaussianDistribution.kl的区别:后者是类方法,针对当前分布与另一个分布(或标准正态)计算KL散度并在指定维度求均值;前者是独立函数,逐元素计算KL散度,不做均值或求和,更灵活(如用于需要保留空间维度的场景)。
  • 来源:注释提到来自openai/guided-diffusion,常用于扩散模型中的损失计算,下面给出一个简单的使用与验证示例。
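
为直观展示 normal_kl 的广播行为以及它与类方法 kl() 的关系,可以做如下简化验证(张量形状为随意假设,导入路径按 AudioLDM 仓库的模块位置书写,若目录结构不同需自行调整):

import torch
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution, normal_kl

# 构造一个假设的4D参数张量:(batch=2, 2*z_channels=8, T=4, F=4)
params = torch.randn(2, 8, 4, 4)
posterior = DiagonalGaussianDistribution(params)

# 类方法:与标准正态比较,并在dim=[1,2,3]上取均值,每个样本得到一个标量
kl_per_sample = posterior.kl()                                          # 形状: (2,)

# 独立函数:标量0.0会被广播成与mean同形状,结果逐元素保留
kl_elementwise = normal_kl(posterior.mean, posterior.logvar, 0.0, 0.0)  # 形状: (2, 4, 4, 4)

# 两者数值上一致:逐元素结果在同样的维度上取均值后应等于类方法的输出
print(torch.allclose(kl_per_sample, kl_elementwise.mean(dim=[1, 2, 3]), atol=1e-5))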

总结

这段代码围绕概率分布的核心操作(采样、众数、KL散度、负对数似然)构建了工具类,尤其聚焦于对角高斯分布——这是生成式模型(如VAE、扩散模型)中潜在变量建模的核心组件。

  • 关键设计考虑:
    • 接口统一:通过AbstractDistribution确保不同分布的调用一致性;
    • 数值稳定性:对logvar截断,避免标准差过大或过小;
    • 灵活性:支持确定性模式(推理用)和概率模式(训练用),并提供通用的KL散度计算函数。

这些工具为后续构建生成模型(如前文的AutoencoderKL)提供了概率建模的基础,特别是在计算损失(如KL损失、重建损失)和潜在变量采样时不可或缺。

modules.py

先给出完整的代码:

# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange

from audioldm.utils import instantiate_from_config
from audioldm.latent_diffusion.attention import LinearAttentiondef get_timestep_embedding(timesteps, embedding_dim):"""This matches the implementation in Denoising Diffusion Probabilistic Models:From Fairseq.Build sinusoidal embeddings.This matches the implementation in tensor2tensor, but differs slightlyfrom the description in Section 3.5 of "Attention Is All You Need"."""assert len(timesteps.shape) == 1half_dim = embedding_dim // 2emb = math.log(10000) / (half_dim - 1)emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)emb = emb.to(device=timesteps.device)emb = timesteps.float()[:, None] * emb[None, :]emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)if embedding_dim % 2 == 1:  # zero pademb = torch.nn.functional.pad(emb, (0, 1, 0, 0))return embdef nonlinearity(x):# swishreturn x * torch.sigmoid(x)def Normalize(in_channels, num_groups=32):return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)class Upsample(nn.Module):def __init__(self, in_channels, with_conv):super().__init__()self.with_conv = with_convif self.with_conv:self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)def forward(self, x):x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")if self.with_conv:x = self.conv(x)return xclass UpsampleTimeStride4(nn.Module):def __init__(self, in_channels, with_conv):super().__init__()self.with_conv = with_convif self.with_conv:self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=5, stride=1, padding=2)def forward(self, x):x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")if self.with_conv:x = self.conv(x)return xclass Downsample(nn.Module):def __init__(self, in_channels, with_conv):super().__init__()self.with_conv = with_convif self.with_conv:# Do time downsampling here# no asymmetric padding in torch conv, must do it ourselvesself.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)def forward(self, x):if self.with_conv:pad = (0, 1, 0, 1)x = torch.nn.functional.pad(x, pad, mode="constant", value=0)x = self.conv(x)else:x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)return xclass DownsampleTimeStride4(nn.Module):def __init__(self, in_channels, with_conv):super().__init__()self.with_conv = with_convif self.with_conv:# Do time downsampling here# no asymmetric padding in torch conv, must do it ourselvesself.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1)def forward(self, x):if self.with_conv:pad = (0, 1, 0, 1)x = torch.nn.functional.pad(x, pad, mode="constant", value=0)x = self.conv(x)else:x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))return xclass ResnetBlock(nn.Module):def __init__(self,*,in_channels,out_channels=None,conv_shortcut=False,dropout,temb_channels=512,):super().__init__()self.in_channels = in_channelsout_channels = in_channels if out_channels is None else out_channelsself.out_channels = out_channelsself.use_conv_shortcut = conv_shortcutself.norm1 = Normalize(in_channels)self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)if temb_channels > 0:self.temb_proj = torch.nn.Linear(temb_channels, out_channels)self.norm2 = Normalize(out_channels)self.dropout = torch.nn.Dropout(dropout)self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)if self.in_channels != self.out_channels:if self.use_conv_shortcut:self.conv_shortcut = 
torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)else:self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)def forward(self, x, temb):h = xh = self.norm1(h)h = nonlinearity(h)h = self.conv1(h)if temb is not None:h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]h = self.norm2(h)h = nonlinearity(h)h = self.dropout(h)h = self.conv2(h)if self.in_channels != self.out_channels:if self.use_conv_shortcut:x = self.conv_shortcut(x)else:x = self.nin_shortcut(x)return x + hclass LinAttnBlock(LinearAttention):"""to match AttnBlock usage"""def __init__(self, in_channels):super().__init__(dim=in_channels, heads=1, dim_head=in_channels)class AttnBlock(nn.Module):def __init__(self, in_channels):super().__init__()self.in_channels = in_channelsself.norm = Normalize(in_channels)self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)def forward(self, x):h_ = xh_ = self.norm(h_)q = self.q(h_)k = self.k(h_)v = self.v(h_)# compute attentionb, c, h, w = q.shapeq = q.reshape(b, c, h * w).contiguous()q = q.permute(0, 2, 1).contiguous()  # b,hw,ck = k.reshape(b, c, h * w).contiguous()  # b,c,hww_ = torch.bmm(q, k).contiguous()  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]w_ = w_ * (int(c) ** (-0.5))w_ = torch.nn.functional.softmax(w_, dim=2)# attend to valuesv = v.reshape(b, c, h * w).contiguous()w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)h_ = torch.bmm(v, w_).contiguous()  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]h_ = h_.reshape(b, c, h, w).contiguous()h_ = self.proj_out(h_)return x + h_def make_attn(in_channels, attn_type="vanilla"):assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"# print(f"making attention of type '{attn_type}' with {in_channels} in_channels")if attn_type == "vanilla":return AttnBlock(in_channels)elif attn_type == "none":return nn.Identity(in_channels)else:return LinAttnBlock(in_channels)class Model(nn.Module):def __init__(self,*,ch,out_ch,ch_mult=(1, 2, 4, 8),num_res_blocks,attn_resolutions,dropout=0.0,resamp_with_conv=True,in_channels,resolution,use_timestep=True,use_linear_attn=False,attn_type="vanilla",):super().__init__()if use_linear_attn:attn_type = "linear"self.ch = chself.temb_ch = self.ch * 4self.num_resolutions = len(ch_mult)self.num_res_blocks = num_res_blocksself.resolution = resolutionself.in_channels = in_channelsself.use_timestep = use_timestepif self.use_timestep:# timestep embeddingself.temb = nn.Module()self.temb.dense = nn.ModuleList([torch.nn.Linear(self.ch, self.temb_ch),torch.nn.Linear(self.temb_ch, self.temb_ch),])# downsamplingself.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)curr_res = resolutionin_ch_mult = (1,) + tuple(ch_mult)self.down = nn.ModuleList()for i_level in range(self.num_resolutions):block = nn.ModuleList()attn = nn.ModuleList()block_in = ch * in_ch_mult[i_level]block_out = ch * ch_mult[i_level]for i_block in range(self.num_res_blocks):block.append(ResnetBlock(in_channels=block_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout,))block_in = block_outif curr_res in attn_resolutions:attn.append(make_attn(block_in, attn_type=attn_type))down = 
nn.Module()down.block = blockdown.attn = attnif i_level != self.num_resolutions - 1:down.downsample = Downsample(block_in, resamp_with_conv)curr_res = curr_res // 2self.down.append(down)# middleself.mid = nn.Module()self.mid.block_1 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)self.mid.block_2 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)# upsamplingself.up = nn.ModuleList()for i_level in reversed(range(self.num_resolutions)):block = nn.ModuleList()attn = nn.ModuleList()block_out = ch * ch_mult[i_level]skip_in = ch * ch_mult[i_level]for i_block in range(self.num_res_blocks + 1):if i_block == self.num_res_blocks:skip_in = ch * in_ch_mult[i_level]block.append(ResnetBlock(in_channels=block_in + skip_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout,))block_in = block_outif curr_res in attn_resolutions:attn.append(make_attn(block_in, attn_type=attn_type))up = nn.Module()up.block = blockup.attn = attnif i_level != 0:up.upsample = Upsample(block_in, resamp_with_conv)curr_res = curr_res * 2self.up.insert(0, up)  # prepend to get consistent order# endself.norm_out = Normalize(block_in)self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)def forward(self, x, t=None, context=None):# assert x.shape[2] == x.shape[3] == self.resolutionif context is not None:# assume aligned context, cat along channel axisx = torch.cat((x, context), dim=1)if self.use_timestep:# timestep embeddingassert t is not Nonetemb = get_timestep_embedding(t, self.ch)temb = self.temb.dense[0](temb)temb = nonlinearity(temb)temb = self.temb.dense[1](temb)else:temb = None# downsamplinghs = [self.conv_in(x)]for i_level in range(self.num_resolutions):for i_block in range(self.num_res_blocks):h = self.down[i_level].block[i_block](hs[-1], temb)if len(self.down[i_level].attn) > 0:h = self.down[i_level].attn[i_block](h)hs.append(h)if i_level != self.num_resolutions - 1:hs.append(self.down[i_level].downsample(hs[-1]))# middleh = hs[-1]h = self.mid.block_1(h, temb)h = self.mid.attn_1(h)h = self.mid.block_2(h, temb)# upsamplingfor i_level in reversed(range(self.num_resolutions)):for i_block in range(self.num_res_blocks + 1):h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)if len(self.up[i_level].attn) > 0:h = self.up[i_level].attn[i_block](h)if i_level != 0:h = self.up[i_level].upsample(h)# endh = self.norm_out(h)h = nonlinearity(h)h = self.conv_out(h)return hdef get_last_layer(self):return self.conv_out.weightclass Encoder(nn.Module):def __init__(self,*,ch,out_ch,ch_mult=(1, 2, 4, 8),num_res_blocks,attn_resolutions,dropout=0.0,resamp_with_conv=True,in_channels,resolution,z_channels,double_z=True,use_linear_attn=False,attn_type="vanilla",downsample_time_stride4_levels=[],**ignore_kwargs,):super().__init__()if use_linear_attn:attn_type = "linear"self.ch = chself.temb_ch = 0self.num_resolutions = len(ch_mult)self.num_res_blocks = num_res_blocksself.resolution = resolutionself.in_channels = in_channelsself.downsample_time_stride4_levels = downsample_time_stride4_levelsif len(self.downsample_time_stride4_levels) > 0:assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ("The level to perform downsample 4 operation need to be smaller than the total resolution number %s"% str(self.num_resolutions))# downsamplingself.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, 
padding=1)curr_res = resolutionin_ch_mult = (1,) + tuple(ch_mult)self.in_ch_mult = in_ch_multself.down = nn.ModuleList()for i_level in range(self.num_resolutions):block = nn.ModuleList()attn = nn.ModuleList()block_in = ch * in_ch_mult[i_level]block_out = ch * ch_mult[i_level]for i_block in range(self.num_res_blocks):block.append(ResnetBlock(in_channels=block_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout,))block_in = block_outif curr_res in attn_resolutions:attn.append(make_attn(block_in, attn_type=attn_type))down = nn.Module()down.block = blockdown.attn = attnif i_level != self.num_resolutions - 1:if i_level in self.downsample_time_stride4_levels:down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)else:down.downsample = Downsample(block_in, resamp_with_conv)curr_res = curr_res // 2self.down.append(down)# middleself.mid = nn.Module()self.mid.block_1 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)self.mid.block_2 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)# endself.norm_out = Normalize(block_in)self.conv_out = torch.nn.Conv2d(block_in,2 * z_channels if double_z else z_channels,kernel_size=3,stride=1,padding=1,)def forward(self, x):# timestep embeddingtemb = None# downsamplinghs = [self.conv_in(x)]for i_level in range(self.num_resolutions):for i_block in range(self.num_res_blocks):h = self.down[i_level].block[i_block](hs[-1], temb)if len(self.down[i_level].attn) > 0:h = self.down[i_level].attn[i_block](h)hs.append(h)if i_level != self.num_resolutions - 1:hs.append(self.down[i_level].downsample(hs[-1]))# middleh = hs[-1]h = self.mid.block_1(h, temb)h = self.mid.attn_1(h)h = self.mid.block_2(h, temb)# endh = self.norm_out(h)h = nonlinearity(h)h = self.conv_out(h)return hclass Decoder(nn.Module):def __init__(self,*,ch,out_ch,ch_mult=(1, 2, 4, 8),num_res_blocks,attn_resolutions,dropout=0.0,resamp_with_conv=True,in_channels,resolution,z_channels,give_pre_end=False,tanh_out=False,use_linear_attn=False,downsample_time_stride4_levels=[],attn_type="vanilla",**ignorekwargs,):super().__init__()if use_linear_attn:attn_type = "linear"self.ch = chself.temb_ch = 0self.num_resolutions = len(ch_mult)self.num_res_blocks = num_res_blocksself.resolution = resolutionself.in_channels = in_channelsself.give_pre_end = give_pre_endself.tanh_out = tanh_outself.downsample_time_stride4_levels = downsample_time_stride4_levelsif len(self.downsample_time_stride4_levels) > 0:assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ("The level to perform downsample 4 operation need to be smaller than the total resolution number %s"% str(self.num_resolutions))# compute in_ch_mult, block_in and curr_res at lowest resin_ch_mult = (1,) + tuple(ch_mult)block_in = ch * ch_mult[self.num_resolutions - 1]curr_res = resolution // 2 ** (self.num_resolutions - 1)self.z_shape = (1, z_channels, curr_res, curr_res)# print("Working with z of shape {} = {} dimensions.".format(# self.z_shape, np.prod(self.z_shape)))# z to block_inself.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)# middleself.mid = nn.Module()self.mid.block_1 = ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)self.mid.block_2 = 
ResnetBlock(in_channels=block_in,out_channels=block_in,temb_channels=self.temb_ch,dropout=dropout,)# upsamplingself.up = nn.ModuleList()for i_level in reversed(range(self.num_resolutions)):block = nn.ModuleList()attn = nn.ModuleList()block_out = ch * ch_mult[i_level]for i_block in range(self.num_res_blocks + 1):block.append(ResnetBlock(in_channels=block_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout,))block_in = block_outif curr_res in attn_resolutions:attn.append(make_attn(block_in, attn_type=attn_type))up = nn.Module()up.block = blockup.attn = attnif i_level != 0:if i_level - 1 in self.downsample_time_stride4_levels:up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)else:up.upsample = Upsample(block_in, resamp_with_conv)curr_res = curr_res * 2self.up.insert(0, up)  # prepend to get consistent order# endself.norm_out = Normalize(block_in)self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)def forward(self, z):# assert z.shape[1:] == self.z_shape[1:]self.last_z_shape = z.shape# timestep embeddingtemb = None# z to block_inh = self.conv_in(z)# middleh = self.mid.block_1(h, temb)h = self.mid.attn_1(h)h = self.mid.block_2(h, temb)# upsamplingfor i_level in reversed(range(self.num_resolutions)):for i_block in range(self.num_res_blocks + 1):h = self.up[i_level].block[i_block](h, temb)if len(self.up[i_level].attn) > 0:h = self.up[i_level].attn[i_block](h)if i_level != 0:h = self.up[i_level].upsample(h)# endif self.give_pre_end:return hh = self.norm_out(h)h = nonlinearity(h)h = self.conv_out(h)if self.tanh_out:h = torch.tanh(h)return hclass SimpleDecoder(nn.Module):def __init__(self, in_channels, out_channels, *args, **kwargs):super().__init__()self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),ResnetBlock(in_channels=in_channels,out_channels=2 * in_channels,temb_channels=0,dropout=0.0,),ResnetBlock(in_channels=2 * in_channels,out_channels=4 * in_channels,temb_channels=0,dropout=0.0,),ResnetBlock(in_channels=4 * in_channels,out_channels=2 * in_channels,temb_channels=0,dropout=0.0,),nn.Conv2d(2 * in_channels, in_channels, 1),Upsample(in_channels, with_conv=True),])# endself.norm_out = Normalize(in_channels)self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)def forward(self, x):for i, layer in enumerate(self.model):if i in [1, 2, 3]:x = layer(x, None)else:x = layer(x)h = self.norm_out(x)h = nonlinearity(h)x = self.conv_out(h)return xclass UpsampleDecoder(nn.Module):def __init__(self,in_channels,out_channels,ch,num_res_blocks,resolution,ch_mult=(2, 2),dropout=0.0,):super().__init__()# upsamplingself.temb_ch = 0self.num_resolutions = len(ch_mult)self.num_res_blocks = num_res_blocksblock_in = in_channelscurr_res = resolution // 2 ** (self.num_resolutions - 1)self.res_blocks = nn.ModuleList()self.upsample_blocks = nn.ModuleList()for i_level in range(self.num_resolutions):res_block = []block_out = ch * ch_mult[i_level]for i_block in range(self.num_res_blocks + 1):res_block.append(ResnetBlock(in_channels=block_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout,))block_in = block_outself.res_blocks.append(nn.ModuleList(res_block))if i_level != self.num_resolutions - 1:self.upsample_blocks.append(Upsample(block_in, True))curr_res = curr_res * 2# endself.norm_out = Normalize(block_in)self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)def forward(self, x):# upsamplingh = xfor k, i_level in 
enumerate(range(self.num_resolutions)):for i_block in range(self.num_res_blocks + 1):h = self.res_blocks[i_level][i_block](h, None)if i_level != self.num_resolutions - 1:h = self.upsample_blocks[k](h)h = self.norm_out(h)h = nonlinearity(h)h = self.conv_out(h)return hclass LatentRescaler(nn.Module):def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):super().__init__()# residual block, interpolate, residual blockself.factor = factorself.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,out_channels=mid_channels,temb_channels=0,dropout=0.0,)for _ in range(depth)])self.attn = AttnBlock(mid_channels)self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,out_channels=mid_channels,temb_channels=0,dropout=0.0,)for _ in range(depth)])self.conv_out = nn.Conv2d(mid_channels,out_channels,kernel_size=1,)def forward(self, x):x = self.conv_in(x)for block in self.res_block1:x = block(x, None)x = torch.nn.functional.interpolate(x,size=(int(round(x.shape[2] * self.factor)),int(round(x.shape[3] * self.factor)),),)x = self.attn(x).contiguous()for block in self.res_block2:x = block(x, None)x = self.conv_out(x)return xclass MergedRescaleEncoder(nn.Module):def __init__(self,in_channels,ch,resolution,out_ch,num_res_blocks,attn_resolutions,dropout=0.0,resamp_with_conv=True,ch_mult=(1, 2, 4, 8),rescale_factor=1.0,rescale_module_depth=1,):super().__init__()intermediate_chn = ch * ch_mult[-1]self.encoder = Encoder(in_channels=in_channels,num_res_blocks=num_res_blocks,ch=ch,ch_mult=ch_mult,z_channels=intermediate_chn,double_z=False,resolution=resolution,attn_resolutions=attn_resolutions,dropout=dropout,resamp_with_conv=resamp_with_conv,out_ch=None,)self.rescaler = LatentRescaler(factor=rescale_factor,in_channels=intermediate_chn,mid_channels=intermediate_chn,out_channels=out_ch,depth=rescale_module_depth,)def forward(self, x):x = self.encoder(x)x = self.rescaler(x)return xclass MergedRescaleDecoder(nn.Module):def __init__(self,z_channels,out_ch,resolution,num_res_blocks,attn_resolutions,ch,ch_mult=(1, 2, 4, 8),dropout=0.0,resamp_with_conv=True,rescale_factor=1.0,rescale_module_depth=1,):super().__init__()tmp_chn = z_channels * ch_mult[-1]self.decoder = Decoder(out_ch=out_ch,z_channels=tmp_chn,attn_resolutions=attn_resolutions,dropout=dropout,resamp_with_conv=resamp_with_conv,in_channels=None,num_res_blocks=num_res_blocks,ch_mult=ch_mult,resolution=resolution,ch=ch,)self.rescaler = LatentRescaler(factor=rescale_factor,in_channels=z_channels,mid_channels=tmp_chn,out_channels=tmp_chn,depth=rescale_module_depth,)def forward(self, x):x = self.rescaler(x)x = self.decoder(x)return xclass Upsampler(nn.Module):def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):super().__init__()assert out_size >= in_sizenum_blocks = int(np.log2(out_size // in_size)) + 1factor_up = 1.0 + (out_size % in_size)print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")self.rescaler = LatentRescaler(factor=factor_up,in_channels=in_channels,mid_channels=2 * in_channels,out_channels=in_channels,)self.decoder = Decoder(out_ch=out_channels,resolution=out_size,z_channels=in_channels,num_res_blocks=2,attn_resolutions=[],in_channels=None,ch=in_channels,ch_mult=[ch_mult for _ in range(num_blocks)],)def forward(self, x):x = self.rescaler(x)x = self.decoder(x)return xclass Resize(nn.Module):def __init__(self, in_channels=None, 
learned=False, mode="bilinear"):super().__init__()self.with_conv = learnedself.mode = modeif self.with_conv:print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")raise NotImplementedError()assert in_channels is not None# no asymmetric padding in torch conv, must do it ourselvesself.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)def forward(self, x, scale_factor=1.0):if scale_factor == 1.0:return xelse:x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)return xclass FirstStagePostProcessor(nn.Module):def __init__(self,ch_mult: list,in_channels,pretrained_model: nn.Module = None,reshape=False,n_channels=None,dropout=0.0,pretrained_config=None,):super().__init__()if pretrained_config is None:assert (pretrained_model is not None), 'Either "pretrained_model" or "pretrained_config" must not be None'self.pretrained_model = pretrained_modelelse:assert (pretrained_config is not None), 'Either "pretrained_model" or "pretrained_config" must not be None'self.instantiate_pretrained(pretrained_config)self.do_reshape = reshapeif n_channels is None:n_channels = self.pretrained_model.encoder.chself.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)self.proj = nn.Conv2d(in_channels, n_channels, kernel_size=3, stride=1, padding=1)blocks = []downs = []ch_in = n_channelsfor m in ch_mult:blocks.append(ResnetBlock(in_channels=ch_in, out_channels=m * n_channels, dropout=dropout))ch_in = m * n_channelsdowns.append(Downsample(ch_in, with_conv=False))self.model = nn.ModuleList(blocks)self.downsampler = nn.ModuleList(downs)def instantiate_pretrained(self, config):model = instantiate_from_config(config)self.pretrained_model = model.eval()# self.pretrained_model.train = Falsefor param in self.pretrained_model.parameters():param.requires_grad = False@torch.no_grad()def encode_with_pretrained(self, x):c = self.pretrained_model.encode(x)if isinstance(c, DiagonalGaussianDistribution):c = c.mode()return cdef forward(self, x):z_fs = self.encode_with_pretrained(x)z = self.proj_norm(z_fs)z = self.proj(z)z = nonlinearity(z)for submodel, downmodel in zip(self.model, self.downsampler):z = submodel(z, temb=None)z = downmodel(z)if self.do_reshape:z = rearrange(z, "b c h w -> b (h w) c")return z

这段代码是基于PyTorch实现的一套生成式模型组件,主要用于扩散模型(Diffusion Model)和变分自编码器(VAE)等生成式任务,尤其适配音频或图像等具有时空结构的数据(从audioldm相关引用可推测更偏向音频处理)。代码包含基础网络模块、扩散模型主网络、编码器、解码器及各类辅助模块,整体设计围绕卷积神经网络(CNN)、残差连接、注意力机制和扩散过程的时间步处理展开。

一、核心工具函数与基础模块

核心工具函数与基础模块是构建深层生成模型(如扩散模型、VAE)的“乐高积木”,负责处理特征变换、分辨率调整、梯度传播和长距离依赖捕捉等基础任务。这些模块的设计直接影响模型的性能、稳定性和适用性(如适配音频/图像等不同数据类型)。以下是详细解析:

1. 时间步嵌入(get_timestep_embedding

扩散模型的核心组件,用于将离散的“时间步”(扩散过程中添加噪声的步骤)转换为连续向量,使模型能学习不同时间步的降噪规律。

def get_timestep_embedding(timesteps, embedding_dim):
    assert len(timesteps.shape) == 1  # 输入为1D张量(batch_size,)
    half_dim = embedding_dim // 2
    # 生成指数衰减的频率(类似Transformer位置编码)
    emb = math.log(10000) / (half_dim - 1)  # 频率缩放因子
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)  # 频率向量
    emb = emb.to(device=timesteps.device)  # 转移到与时间步相同的设备
    # 时间步与频率相乘,得到基础嵌入
    emb = timesteps.float()[:, None] * emb[None, :]  # 形状:(batch_size, half_dim)
    # 拼接正弦和余弦分量,增强表达能力
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)  # 形状:(batch_size, embedding_dim)
    # 若嵌入维度为奇数,补零以匹配维度
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))  # 补最后一维
    return emb

关键细节

  • 原理类似Transformer的位置编码,通过正弦/余弦函数将离散时间步映射到连续空间,确保“时间距离”与“嵌入向量距离”一致(如时间步1和2的嵌入比1和10更相似)。
  • 输出维度为embedding_dim,后续会通过全连接层进一步加工,融入模型的每个残差块,使模型能感知当前处于扩散过程的哪个阶段。
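
一个简单的调用示意如下(batch 内的时间步取值和 embedding_dim 均为假设值),可以看到输出是逐样本的固定维度向量:

import torch
from audioldm.variational_autoencoder.modules import get_timestep_embedding

# 假设batch内4个样本分别处于第0、10、100、999个扩散时间步
timesteps = torch.tensor([0, 10, 100, 999])
emb = get_timestep_embedding(timesteps, embedding_dim=128)
print(emb.shape)  # torch.Size([4, 128]):前64维为sin分量,后64维为cos分量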

2. 非线性激活与归一化

2.1 Swish激活函数(nonlinearity
def nonlinearity(x):
    # swish激活:x * sigmoid(x)
    return x * torch.sigmoid(x)
  • 优点:相比ReLU,Swish在负值区域不直接截断,而是平滑衰减,有助于缓解梯度消失,尤其适合深层网络。
  • 无需额外参数,计算高效,在生成模型中广泛替代ReLU。
2.2 组归一化(Normalize
def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )
  • 原理:将输入特征的通道分成num_groups组,每组内独立计算均值和方差并归一化,最后通过可学习参数(affine=True)缩放和平移。
  • 优势:不受批量大小影响(避免BatchNorm在小批量时的不稳定),适合生成模型(通常批量较小)。
  • 音频/图像特征通常通道数较多(如256、512),num_groups=32是经验值,平衡计算效率和归一化效果。
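
下面用一个小张量演示这两个工具的组合用法(通道数取 64,可以被默认的 32 组整除;形状为假设值):

import torch
from audioldm.variational_autoencoder.modules import Normalize, nonlinearity

x = torch.randn(2, 64, 16, 16)   # (batch, channels, H, W),通道数需能被num_groups整除
norm = Normalize(64)             # 即 GroupNorm(num_groups=32, num_channels=64)
y = nonlinearity(norm(x))        # 先组归一化,再Swish激活,形状保持不变
print(y.shape)                   # torch.Size([2, 64, 16, 16])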

3. 上采样与下采样模块

用于调整特征图的空间/时间分辨率(如音频频谱的“时间步×频率 bins”维度),是多尺度特征学习的核心。

3.1 普通上采样(Upsample
class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:  # 上采样后是否用卷积调整特征
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        # 近邻插值上采样,尺度放大2倍(适用于空间/时间维度)
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:  # 上采样后用3x3卷积消除锯齿,增强特征表达
            x = self.conv(x)
        return x
  • 作用:将特征图分辨率翻倍(如28×28→56×56),with_conv=True时通过卷积整合插值后的特征,避免高频噪声。
3.2 音频专用上采样(UpsampleTimeStride4
class UpsampleTimeStride4(nn.Module):
    # __init__ 与 Upsample 类似,只是卷积核改为 5x5(此处省略,完整实现见上文整段代码)
    def forward(self, x):
        # 时间维度放大4倍,频率维度放大2倍(适配音频频谱特性)
        x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
        if self.with_conv:
            x = self.conv(x)  # 5x5卷积更适合捕捉音频的长时相关性
        return x
  • 设计动机:音频频谱的时间维度(如毫秒级)比频率维度(如Hz)需要更大的缩放倍数(4倍 vs 2倍),以恢复时间连续性。
3.3 普通下采样(Downsample
class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # 3x3卷积+步长2实现下采样(替代池化,保留更多特征)
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            # 不对称 padding(右、下各补1),确保下采样后尺寸为原图1/2
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            # 平均池化下采样(更简单,但可能丢失细节)
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
  • 作用:将特征图分辨率减半(如56×56→28×28),with_conv=True时用卷积下采样,比池化保留更多空间结构信息。
3.4 音频专用下采样(DownsampleTimeStride4
class DownsampleTimeStride4(nn.Module):
    # __init__ 与 Downsample 类似,卷积核改为 5x5、步长 (4, 2)(此处省略,完整实现见上文整段代码)
    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            # 时间步长4,频率步长2,适配音频下采样需求
            x = self.conv(x)  # 5x5卷积捕捉更大范围特征
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
        return x
  • 应用:与UpsampleTimeStride4对应,在编码阶段压缩音频频谱的时间维度(4倍)和频率维度(2倍),减少计算量。
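
下面用一个类似梅尔频谱的假设张量(batch, 通道, 时间帧, 频率 bins)演示这几类采样模块对时间/频率维度的不同缩放效果:

import torch
from audioldm.variational_autoencoder.modules import (
    Downsample, Upsample, DownsampleTimeStride4, UpsampleTimeStride4,
)

x = torch.randn(1, 8, 256, 64)   # (batch, channels, 时间帧, 频率bins),数值为假设

print(Downsample(8, with_conv=True)(x).shape)            # [1, 8, 128, 32]:两个维度都减半
print(DownsampleTimeStride4(8, with_conv=True)(x).shape) # [1, 8, 64, 32]:时间÷4,频率÷2
print(Upsample(8, with_conv=True)(x).shape)              # [1, 8, 512, 128]:两个维度都翻倍
print(UpsampleTimeStride4(8, with_conv=True)(x).shape)   # [1, 8, 1024, 128]:时间×4,频率×2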

4. 残差块(ResnetBlock

生成式模型的基础构建块,通过残差连接解决深层网络的梯度消失问题,同时支持融入时间步嵌入(扩散模型专用)。

class ResnetBlock(nn.Module):def __init__(self,*,in_channels,out_channels=None,conv_shortcut=False,dropout,temb_channels=512,  # 时间步嵌入维度):super().__init__()self.in_channels = in_channelsout_channels = in_channels if out_channels is None else out_channelsself.out_channels = out_channelsself.use_conv_shortcut = conv_shortcut  # 残差连接是否用卷积调整通道# 第一层:归一化+激活+卷积self.norm1 = Normalize(in_channels)self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)# 时间步嵌入投影(扩散模型用)if temb_channels > 0:self.temb_proj = torch.nn.Linear(temb_channels, out_channels)# 第二层:归一化+激活+dropout+卷积self.norm2 = Normalize(out_channels)self.dropout = torch.nn.Dropout(dropout)self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)# 残差连接(输入输出通道不同时需要调整)if self.in_channels != self.out_channels:if self.use_conv_shortcut:self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)else:self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)  # 1x1卷积调整通道def forward(self, x, temb):h = x  # 残差分支# 第一层处理h = self.norm1(h)  # 组归一化h = nonlinearity(h)  # Swish激活h = self.conv1(h)  # 3x3卷积升维/保持维度# 融入时间步嵌入(仅扩散模型,temb不为None时)if temb is not None:# 时间步嵌入先激活,再投影到当前通道数,最后广播到特征图维度h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]# 第二层处理h = self.norm2(h)h = nonlinearity(h)h = self.dropout(h)  # 防止过拟合h = self.conv2(h)# 残差连接(输入x与残差分支h相加)if self.in_channels != self.out_channels:# 若通道不同,先调整x的通道x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)return x + h  # 残差相加,保留原始特征并叠加新特征

核心作用

  • 残差连接(x + h)使梯度能直接从输出流回输入,解决深层网络梯度消失问题。
  • 支持融入时间步嵌入(temb),让每个残差块都能感知当前扩散阶段,是扩散模型的关键设计。
  • 两层卷积+归一化+激活的结构,能有效提取局部特征,同时通过dropout增强泛化能力。
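
下面给出一个示意性的调用(通道数、dropout 和时间步嵌入维度均为假设值),分别演示带 temb(扩散模型)和不带 temb(编码器/解码器)两种用法:

import torch
from audioldm.variational_autoencoder.modules import ResnetBlock

x = torch.randn(2, 64, 32, 32)   # 输入特征,通道数需能被GroupNorm的32组整除
temb = torch.randn(2, 512)       # 假设的时间步嵌入,维度与temb_channels对应

# 扩散模型用法:temb_channels > 0,forward时传入时间步嵌入
block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.1, temb_channels=512)
print(block(x, temb).shape)      # torch.Size([2, 128, 32, 32])

# 编码器/解码器用法:temb_channels=0,forward时传None
block_ae = ResnetBlock(in_channels=64, out_channels=128, dropout=0.1, temb_channels=0)
print(block_ae(x, None).shape)   # torch.Size([2, 128, 32, 32])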

5. 注意力模块

捕捉特征图中的长距离依赖(如音频的长时相关性、图像的全局结构),分为普通注意力和线性注意力。

5.1 普通自注意力(AttnBlock
class AttnBlock(nn.Module):def __init__(self, in_channels):super().__init__()self.in_channels = in_channelsself.norm = Normalize(in_channels)  # 归一化# 1x1卷积生成查询(q)、键(k)、值(v)self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)# 注意力输出投影self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)def forward(self, x):h_ = x  # 注意力分支h_ = self.norm(h_)  # 归一化# 生成q、k、v(形状:(batch, channels, height, width))q = self.q(h_)k = self.k(h_)v = self.v(h_)# 展平空间维度(height×width → 序列长度N)b, c, h, w = q.shapeq = q.reshape(b, c, h * w).contiguous()  # (b, c, N)q = q.permute(0, 2, 1).contiguous()  # (b, N, c):每个位置的查询向量k = k.reshape(b, c, h * w).contiguous()  # (b, c, N):每个位置的键向量# 计算注意力权重:q与k的点积,缩放避免数值过大w_ = torch.bmm(q, k).contiguous()  # (b, N, N):每个位置对其他位置的注意力w_ = w_ * (int(c) ** -0.5)  # 缩放因子:通道数的平方根倒数w_ = torch.nn.functional.softmax(w_, dim=2)  # 归一化权重# 注意力加权求和:用权重w_对v加权v = v.reshape(b, c, h * w).contiguous()  # (b, c, N):每个位置的值向量w_ = w_.permute(0, 2, 1).contiguous()  # (b, N, N):转置权重用于加权h_ = torch.bmm(v, w_).contiguous()  # (b, c, N):加权后的值h_ = h_.reshape(b, c, h, w).contiguous()  # 恢复空间维度# 投影输出+残差连接h_ = self.proj_out(h_)return x + h_  # 原始特征与注意力特征相加

关键细节

  • 本质是“空间自注意力”,将特征图的每个空间位置(如音频频谱的每个“时间-频率点”)视为序列元素,计算位置间的依赖关系。
  • 复杂度为 O(N²)(N 为空间位置数),适合低分辨率特征图(如32×32),高分辨率时计算成本过高。
5.2 线性注意力(LinAttnBlock
class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""

    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
  • 继承自LinearAttention(外部实现),通过核函数近似将注意力复杂度从 O(N²) 降至 O(N),适合高分辨率特征图。
  • 牺牲部分表达能力换取效率,在对长距离依赖要求不极致的场景(如音频局部结构)中表现良好。
5.3 注意力选择器(make_attn
def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
    if attn_type == "vanilla":
        return AttnBlock(in_channels)  # 普通注意力
    elif attn_type == "none":
        return nn.Identity(in_channels)  # 无注意力(仅残差连接)
    else:
        return LinAttnBlock(in_channels)  # 线性注意力
  • 灵活选择注意力类型,根据任务需求(精度vs效率)切换,增强模型适应性。
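
AttnBlock 不改变特征形状,只在空间位置之间交换信息,下面是一个简单的形状检查(通道数为假设值,需能被 GroupNorm 的 32 组整除):

import torch
from audioldm.variational_autoencoder.modules import make_attn

x = torch.randn(1, 64, 16, 16)           # 16x16共256个空间位置
attn = make_attn(64, attn_type="vanilla")
print(attn(x).shape)                      # torch.Size([1, 64, 16, 16]):形状不变,内部是残差连接

noop = make_attn(64, attn_type="none")    # 返回nn.Identity,直接透传
print(noop(x).shape)                      # torch.Size([1, 64, 16, 16])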

总结

这些基础模块是生成模型的“基础设施”:

  • 时间步嵌入让模型感知扩散阶段;
  • Swish+GroupNorm确保特征变换的稳定性和表达力;
  • 上/下采样实现多尺度特征学习,适配音频/图像的分辨率需求;
  • 残差块解决深层网络训练难题,同时融入时间信息;
  • 注意力模块捕捉长距离依赖,增强全局特征建模能力。

它们的组合构成了扩散模型、VAE等复杂生成模型的主体框架,其设计细节(如音频专用的采样策略)直接影响模型对特定数据类型的适配性。

二、核心网络结构

核心网络结构是这套代码的"主体框架",包括扩散模型主网络(Model)、编码器(Encoder)和解码器(Decoder),分别对应生成式模型的三个核心功能:**降噪生成**、**数据压缩编码**和**潜在变量解码重建**。这些结构基于前述基础模块(残差块、注意力、采样层等)搭建,针对时空数据(如音频频谱)的特性优化,下面详细解析:

1. 扩散模型主网络(Model

扩散模型(如DDPM)的核心是“降噪网络”,通过学习从含噪声数据中恢复原始数据的规律,实现生成。Model类基于U-Net结构设计,支持时间步感知和注意力机制,是扩散过程的核心执行者。

1.1 初始化参数(__init__
def __init__(
    self,
    *,
    ch,  # 基础通道数
    out_ch,  # 输出通道数(与输入数据通道一致)
    ch_mult=(1, 2, 4, 8),  # 各分辨率阶段的通道倍增因子
    num_res_blocks,  # 每个分辨率阶段的残差块数量
    attn_resolutions,  # 需要添加注意力的分辨率
    dropout=0.0,  # dropout概率
    resamp_with_conv=True,  # 采样时是否用卷积(而非池化)
    in_channels,  # 输入数据通道数
    resolution,  # 输入数据分辨率(如64x64)
    use_timestep=True,  # 是否使用时间步嵌入(扩散模型必须)
    use_linear_attn=False,  # 是否使用线性注意力
    attn_type="vanilla",  # 注意力类型(普通/线性/无)
):
  • 核心参数解析
    • ch_mult:控制网络深度,如(1,2,4,8)表示4个分辨率阶段,通道数依次为ch×1ch×2ch×4ch×8(逐步加深)。
    • attn_resolutions:指定在哪些分辨率下添加注意力(如(32, 16)表示在32×32和16×16分辨率时使用注意力)。
    • use_timestep:扩散模型的标志,启用时间步嵌入以区分不同降噪阶段。
1.2 网络结构与前向传播(forward

网络整体遵循“下采样→中间层→上采样”的U-Net流程,结合时间步嵌入和跳连(Skip Connection)融合多尺度特征。

def forward(self, x, t=None, context=None):# 步骤1:处理条件输入(可选,如文本引导生成)if context is not None:x = torch.cat((x, context), dim=1)  # 沿通道维度拼接输入与条件# 步骤2:生成时间步嵌入(扩散模型核心)if self.use_timestep:assert t is not None  # 必须传入时间步ttemb = get_timestep_embedding(t, self.ch)  # 生成基础嵌入temb = self.temb.dense[0](temb)  # 全连接层加工temb = nonlinearity(temb)temb = self.temb.dense[1](temb)  # 最终时间步嵌入(维度:temb_ch=ch×4)else:temb = None  # 非扩散模型不使用# 步骤3:下采样(Downsampling):逐步降低分辨率,提取高层特征hs = [self.conv_in(x)]  # 初始卷积:将输入通道转为基础通道chfor i_level in range(self.num_resolutions):  # 遍历每个分辨率阶段# 每个阶段包含多个残差块for i_block in range(self.num_res_blocks):h = self.down[i_level].block[i_block](hs[-1], temb)  # 残差块处理(融入时间步)if len(self.down[i_level].attn) > 0:  # 若当前分辨率需注意力h = self.down[i_level].attn[i_block](h)  # 注意力增强特征hs.append(h)  # 保存特征用于后续跳连# 阶段结束时下采样(最后一个阶段不采样)if i_level != self.num_resolutions - 1:hs.append(self.down[i_level].downsample(hs[-1]))  # 分辨率减半# 步骤4:中间层(Middle):处理最深层特征h = hs[-1]  # 取最深层特征h = self.mid.block_1(h, temb)  # 残差块h = self.mid.attn_1(h)  # 注意力h = self.mid.block_2(h, temb)  # 残差块# 步骤5:上采样(Upsampling):逐步恢复分辨率,融合跳连特征for i_level in reversed(range(self.num_resolutions)):  # 逆序遍历分辨率阶段# 每个阶段包含多个残差块,融合下采样时的跳连特征for i_block in range(self.num_res_blocks + 1):# 拼接当前特征与下采样阶段的对应特征(跳连)h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)if len(self.up[i_level].attn) > 0:  # 注意力增强h = self.up[i_level].attn[i_block](h)# 阶段结束时上采样(第一个阶段不采样)if i_level != 0:h = self.up[i_level].upsample(h)  # 分辨率翻倍# 步骤6:输出层:预测噪声或原始数据h = self.norm_out(h)  # 归一化h = nonlinearity(h)  # 激活h = self.conv_out(h)  # 卷积输出(通道数=out_ch)return h
1.3 核心设计亮点
  • U-Net结构:下采样(down)通过多层残差块和下采样层逐步压缩空间维度(如64→32→16→8),提取抽象特征;上采样(up)通过上采样层和残差块逐步恢复维度,同时融合下采样阶段的跳连特征(hs.pop()),保留细节信息。
  • 时间步嵌入:时间步t通过get_timestep_embedding转为向量后,融入每个残差块(h += self.temb_proj(...)),使模型能学习不同噪声水平(不同t)的降噪规律。
  • 条件生成支持:通过context参数接收条件信号(如文本嵌入),与输入数据拼接后送入网络,实现条件生成(如“生成钢琴声”)。

2. 编码器(Encoder

编码器的作用是将原始数据(如音频频谱)压缩为潜在空间变量,通常用于VAE或作为扩散模型的“第一阶段”(将高维数据映射到低维潜在空间,降低生成难度)。其结构类似Model的下采样部分,但无时间步嵌入(非扩散过程)。

2.1 初始化参数(__init__

Model参数基础上新增:

def __init__(
    self,
    ...,
    z_channels,  # 潜在变量通道数
    double_z=True,  # 是否输出两倍通道(均值+对数方差,用于VAE的高斯分布)
    downsample_time_stride4_levels=[],  # 哪些阶段使用音频专用下采样(时间步长4)
    ...
):
  • double_z=True是VAE的典型设计:输出通道数为2×z_channels,前半为潜在变量的均值(mean),后半为对数方差(logvar),用于构建对角高斯分布(见前文DiagonalGaussianDistribution)。
2.2 前向传播(forward
def forward(self, x):temb = None  # 编码器无时间步(非扩散过程)# 下采样:与Model的下采样流程一致,但无时间步嵌入hs = [self.conv_in(x)]  # 初始卷积for i_level in range(self.num_resolutions):for i_block in range(self.num_res_blocks):h = self.down[i_level].block[i_block](hs[-1], temb)  # 残差块(无时间步)if len(self.down[i_level].attn) > 0:h = self.down[i_level].attn[i_block](h)hs.append(h)if i_level != self.num_resolutions - 1:# 支持音频专用下采样(时间步长4)hs.append(self.down[i_level].downsample(hs[-1]))# 中间层:处理最深层特征h = hs[-1]h = self.mid.block_1(h, temb)h = self.mid.attn_1(h)h = self.mid.block_2(h, temb)# 输出潜在变量参数h = self.norm_out(h)h = nonlinearity(h)h = self.conv_out(h)  # 输出:2×z_channels(均值+对数方差)或z_channelsreturn h
2.3 核心作用
  • 数据压缩:通过多层下采样将高分辨率输入(如64×64频谱)压缩为低分辨率潜在变量(如8×8),减少数据维度。
  • 概率建模:输出潜在变量的分布参数(均值+对数方差),为VAE的“变分”特性提供基础(通过KL散度约束潜在分布接近标准正态分布)。

3. 解码器(Decoder

解码器是编码器的逆过程,将潜在变量重建为原始数据(如从VAE的潜在变量重建音频频谱)。结构与Model的上采样部分类似,同样无时间步嵌入。

3.1 初始化参数(__init__

Encoder参数基础上新增:

def __init__(
    self,
    ...,
    give_pre_end=False,  # 是否返回最终卷积前的特征
    tanh_out=False,  # 是否用tanh约束输出范围(如[-1,1])
    ...
):
  • tanh_out=True用于约束输出在[-1,1]范围内,适合输入数据已归一化的场景(如音频频谱通常归一化到该范围)。
3.2 前向传播(forward
def forward(self, z):temb = None  # 无时间步# 潜在变量映射到高层特征h = self.conv_in(z)  # 将潜在变量通道转为解码器基础通道# 中间层:处理高层特征h = self.mid.block_1(h, temb)h = self.mid.attn_1(h)h = self.mid.block_2(h, temb)# 上采样:逐步恢复分辨率(与编码器下采样对称)for i_level in reversed(range(self.num_resolutions)):for i_block in range(self.num_res_blocks + 1):h = self.up[i_level].block[i_block](h, temb)  # 无跳连(与Encoder对称,无需融合)if len(self.up[i_level].attn) > 0:h = self.up[i_level].attn[i_block](h)if i_level != 0:# 支持音频专用上采样(时间×4,频率×2)h = self.up[i_level].upsample(h)# 输出重建结果if self.give_pre_end:return h  # 返回最终卷积前的特征(用于中间任务)h = self.norm_out(h)h = nonlinearity(h)h = self.conv_out(h)  # 输出通道=out_ch(与原始数据一致)if self.tanh_out:h = torch.tanh(h)  # 约束范围return h
3.3 核心作用
  • 重建数据:从潜在变量恢复原始数据的分辨率和细节,与编码器形成“编码-解码”闭环(如VAE的重建损失基于解码器输出与输入的差异)。
  • 对称性设计:与编码器的下采样阶段严格对称(分辨率变化、通道数变化一致),确保潜在变量能被准确解码。

核心网络的协同关系

这三个网络并非孤立存在,而是通过“组合”实现复杂生成任务:

  1. VAE模式Encoder编码输入为潜在分布→采样潜在变量→Decoder重建输入,通过重建损失+KL散度训练,学习数据的压缩表示。
  2. 扩散+VAE模式:先用Encoder将数据映射到低维潜在空间→在潜在空间用Model(扩散模型)学习生成→再用Decoder将生成的潜在变量解码为最终数据(降低扩散模型的计算成本)。

这种组合兼顾了VAE的压缩能力和扩散模型的生成质量,是现代生成式模型(如Stable Diffusion、AudioLDM)的主流架构。
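
"VAE模式"可以用一个最小示意串起来(下面的 ch、ch_mult、分辨率等都是随意假设的小配置,与 AudioLDM 的实际配置无关,仅用于观察形状变化):

import torch
from audioldm.variational_autoencoder.modules import Encoder, Decoder
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution

cfg = dict(ch=32, out_ch=1, ch_mult=(1, 2, 4), num_res_blocks=1,
           attn_resolutions=[], in_channels=1, resolution=64, z_channels=4)

encoder = Encoder(double_z=True, **cfg)   # 输出2*z_channels通道(均值+对数方差)
decoder = Decoder(**cfg)                  # 从z_channels通道的潜在变量重建

x = torch.randn(1, 1, 64, 64)             # 假设的输入"频谱"
moments = encoder(x)                      # (1, 8, 16, 16):两次下采样,通道为2*z_channels
posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()                    # (1, 4, 16, 16):训练时采样
# z = posterior.mode()                    # 推理时可直接取均值
recon = decoder(z)                        # (1, 1, 64, 64):分辨率恢复到输入大小
print(moments.shape, z.shape, recon.shape)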

总结

核心网络结构围绕“生成”与“重建”设计:

  • Model作为扩散模型的降噪核心,通过U-Net结构和时间步嵌入学习从噪声中恢复数据的规律;
  • EncoderDecoder构成自编码器,实现数据与潜在空间的双向映射,支持压缩、重建和作为扩散模型的前置处理。

三者均基于残差块、注意力和采样层构建,兼顾深层特征提取、长距离依赖捕捉和多尺度信息融合,为音频/图像等时空数据的生成任务提供了强大的基础架构。

三、辅助模块

辅助模块是对核心网络的补充与扩展,主要解决特定场景需求,如分辨率调整、轻量级重建、预训练模型融合等,增强了整个框架的灵活性和适用性(尤其针对音频生成的特殊需求)。以下是详细解析:

1. 简化解码器(SimpleDecoder

轻量化解码器,用于快速重建或低复杂度场景,通过少量残差块和上采样实现基础的特征恢复。

1.1 结构与前向传播
class SimpleDecoder(nn.Module):def __init__(self, in_channels, out_channels, *args, **kwargs):super().__init__()# 模块列表:1x1卷积→残差块×3→1x1卷积→上采样self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),  # 通道调整ResnetBlock(in_channels=in_channels, out_channels=2*in_channels, temb_channels=0, dropout=0.0),ResnetBlock(in_channels=2*in_channels, out_channels=4*in_channels, temb_channels=0, dropout=0.0),ResnetBlock(in_channels=4*in_channels, out_channels=2*in_channels, temb_channels=0, dropout=0.0),nn.Conv2d(2*in_channels, in_channels, 1),  # 通道压缩Upsample(in_channels, with_conv=True)  # 上采样(尺度×2)])# 输出层self.norm_out = Normalize(in_channels)self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)def forward(self, x):for i, layer in enumerate(self.model):if i in [1, 2, 3]:  # 残差块需要传入temb(此处为None,非扩散过程)x = layer(x, None)else:x = layer(x)# 最终处理h = self.norm_out(x)h = nonlinearity(h)return self.conv_out(h)
1.2 核心作用
  • 轻量重建:相比完整Decoder,用更少的残差块(3个)实现基础的特征解码,适合资源受限场景或作为辅助解码器。
  • 固定流程:结构固定(通道先扩张后压缩+上采样),无需复杂参数配置,快速适配简单重建任务(如音频频谱的粗略恢复)。

2. 上采样解码器(UpsampleDecoder

专注于分辨率提升的解码器,通过多阶段残差块和上采样,逐步将低分辨率特征恢复为高分辨率输出。

2.1 结构与前向传播
class UpsampleDecoder(nn.Module):def __init__(self,in_channels,out_channels,ch,  # 基础通道数num_res_blocks,  # 每个阶段的残差块数量resolution,  # 目标分辨率ch_mult=(2, 2),  # 通道倍增因子dropout=0.0,):super().__init__()self.temb_ch = 0  # 无时间步self.num_resolutions = len(ch_mult)self.num_res_blocks = num_res_blocksblock_in = in_channels  # 输入通道curr_res = resolution // (2 ** (self.num_resolutions - 1))  # 初始分辨率(逐步上采样至目标)# 残差块组(每个阶段一组)和上采样层self.res_blocks = nn.ModuleList()self.upsample_blocks = nn.ModuleList()for i_level in range(self.num_resolutions):res_block = []block_out = ch * ch_mult[i_level]  # 当前阶段输出通道for i_block in range(self.num_res_blocks + 1):res_block.append(ResnetBlock(in_channels=block_in,out_channels=block_out,temb_channels=self.temb_ch,dropout=dropout))block_in = block_outself.res_blocks.append(nn.ModuleList(res_block))# 除最后一个阶段外,添加上采样层if i_level != self.num_resolutions - 1:self.upsample_blocks.append(Upsample(block_in, True))curr_res *= 2  # 分辨率翻倍# 输出层self.norm_out = Normalize(block_in)self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)def forward(self, x):h = x# 逐阶段处理:残差块→上采样(最后阶段不上采样)for k, i_level in enumerate(range(self.num_resolutions)):for i_block in range(self.num_res_blocks + 1):h = self.res_blocks[i_level][i_block](h, None)  # 残差块处理if i_level != self.num_resolutions - 1:h = self.upsample_blocks[k](h)  # 上采样# 输出处理h = self.norm_out(h)h = nonlinearity(h)return self.conv_out(h)
2.2 核心作用

  • 渐进式上采样:通过多阶段残差块和上采样,逐步将低分辨率特征(如8×8)恢复为目标分辨率(如32×32),避免单次大尺度上采样导致的细节丢失。
  • 灵活配置:通过ch_mult和num_res_blocks控制网络深度和通道扩张,适配不同分辨率提升需求(如音频频谱的时间维度拉长)。

3. 潜在空间尺度调整器(LatentRescaler

调整潜在变量的分辨率(缩放尺寸),同时通过残差块和注意力保持特征一致性,用于适配不同网络间的潜在空间尺度差异。

3.1 结构与前向传播
class LatentRescaler(nn.Module):def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):super().__init__()self.factor = factor  # 缩放因子(如1.5倍)# 输入卷积:调整通道至中间维度self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)# 残差块组1:缩放前增强特征self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,out_channels=mid_channels,temb_channels=0,dropout=0.0,) for _ in range(depth)])self.attn = AttnBlock(mid_channels)  # 注意力:捕捉长距离依赖# 残差块组2:缩放后修复特征self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,out_channels=mid_channels,temb_channels=0,dropout=0.0,) for _ in range(depth)])# 输出卷积:调整通道至目标维度self.conv_out = nn.Conv2d(mid_channels, out_channels, kernel_size=1)def forward(self, x):x = self.conv_in(x)  # 通道调整# 缩放前特征增强for block in self.res_block1:x = block(x, None)# 尺度调整(插值缩放)x = torch.nn.functional.interpolate(x,size=(int(round(x.shape[2] * self.factor)),  # 高度×缩放因子int(round(x.shape[3] * self.factor))   # 宽度×缩放因子),)x = self.attn(x).contiguous()  # 注意力修复缩放后的特征一致性# 缩放后特征增强for block in self.res_block2:x = block(x, None)return self.conv_out(x)  # 输出目标通道
3.2 核心作用

  • 跨尺度适配:解决不同网络输出的潜在变量分辨率不匹配问题(如将VAE的16×16潜在变量缩放到扩散模型所需的24×24)。
  • 特征保持:缩放前后通过残差块增强特征,中间用注意力修复缩放导致的空间一致性丢失,确保调整后的潜在变量仍保留关键信息。

4. 融合尺度调整的编码器/解码器

将基础编码器/解码器与LatentRescaler结合,实现“编码+尺度调整”或“尺度调整+解码”的端到端流程,简化多阶段生成 pipeline。

4.1 MergedRescaleEncoder(编码+缩放)
class MergedRescaleEncoder(nn.Module):def __init__(self,in_channels,ch,resolution,out_ch,  # 最终输出通道num_res_blocks,attn_resolutions,dropout=0.0,resamp_with_conv=True,ch_mult=(1, 2, 4, 8),rescale_factor=1.0,  # 缩放因子rescale_module_depth=1,  # 残差块深度):super().__init__()intermediate_chn = ch * ch_mult[-1]  # 编码器输出通道# 基础编码器self.encoder = Encoder(in_channels=in_channels,num_res_blocks=num_res_blocks,ch=ch,ch_mult=ch_mult,z_channels=intermediate_chn,double_z=False,  # 不输出均值+方差,仅输出特征resolution=resolution,attn_resolutions=attn_resolutions,dropout=dropout,resamp_with_conv=resamp_with_conv,out_ch=None,)# 编码后缩放self.rescaler = LatentRescaler(factor=rescale_factor,in_channels=intermediate_chn,mid_channels=intermediate_chn,out_channels=out_ch,depth=rescale_module_depth,)def forward(self, x):x = self.encoder(x)  # 先编码x = self.rescaler(x)  # 再缩放return x
4.2 MergedRescaleDecoder(缩放+解码)
class MergedRescaleDecoder(nn.Module):def __init__(self,z_channels,  # 输入潜在变量通道out_ch,  # 输出通道resolution,num_res_blocks,attn_resolutions,ch,ch_mult=(1, 2, 4, 8),dropout=0.0,resamp_with_conv=True,rescale_factor=1.0,rescale_module_depth=1,):super().__init__()tmp_chn = z_channels * ch_mult[-1]  # 中间通道# 基础解码器self.decoder = Decoder(out_ch=out_ch,z_channels=tmp_chn,attn_resolutions=attn_resolutions,dropout=dropout,resamp_with_conv=resamp_with_conv,in_channels=None,num_res_blocks=num_res_blocks,ch_mult=ch_mult,resolution=resolution,ch=ch,)# 解码前缩放self.rescaler = LatentRescaler(factor=rescale_factor,in_channels=z_channels,mid_channels=tmp_chn,out_channels=tmp_chn,depth=rescale_module_depth,)def forward(self, x):x = self.rescaler(x)  # 先缩放x = self.decoder(x)  # 再解码return x
4.3 核心作用
  • 流程简化:将“编码→缩放”或“缩放→解码”合并为单一模块,减少多阶段模型的调用复杂度(如VAE编码后直接适配扩散模型的潜在空间)。
  • 参数联动:确保编码器/解码器与缩放器的通道、分辨率参数匹配,避免手动调整的错误。

5. 通用上采样器(Upsampler

针对特定分辨率提升需求设计的专用上采样器,结合LatentRescalerDecoder,实现从低分辨率到高分辨率的端到端生成。

class Upsampler(nn.Module):def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):super().__init__()assert out_size >= in_size  # 确保输出分辨率不小于输入# 计算上采样阶段数(基于2的对数)num_blocks = int(np.log2(out_size // in_size)) + 1# 初始缩放因子(处理非2的幂次倍数)factor_up = 1.0 + (out_size % in_size)# 先缩放再解码self.rescaler = LatentRescaler(factor=factor_up,in_channels=in_channels,mid_channels=2 * in_channels,out_channels=in_channels,)self.decoder = Decoder(out_ch=out_channels,resolution=out_size,z_channels=in_channels,num_res_blocks=2,attn_resolutions=[],in_channels=None,ch=in_channels,ch_mult=[ch_mult for _ in range(num_blocks)],  # 每个阶段通道倍增)def forward(self, x):x = self.rescaler(x)  # 初步缩放x = self.decoder(x)  # 解码至目标分辨率return x
5.1 核心作用
  • 目标导向:直接指定输入(in_size)和输出(out_size)分辨率,自动计算所需上采样阶段和因子,无需手动配置。
  • 专用优化:针对分辨率提升场景(如音频从32kHz升至44.1kHz)优化,平衡速度与质量。

6. 分辨率调整器(Resize

轻量级分辨率调整工具,支持学习式或固定插值方式调整特征图尺度,用于简单的分辨率适配。

class Resize(nn.Module):def __init__(self, in_channels=None, learned=False, mode="bilinear"):super().__init__()self.with_conv = learned  # 是否用学习式卷积调整self.mode = mode  # 插值方式(双线性/近邻等)if self.with_conv:# 学习式下采样(未实现,预留接口)raise NotImplementedError()self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)def forward(self, x, scale_factor=1.0):if scale_factor == 1.0:return x  # 无需调整else:# 固定插值调整(非学习式)x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)return x
6.1 核心作用
  • 轻量适配:用于无需复杂特征修复的场景(如中间特征的临时分辨率调整),计算成本低。
  • 灵活模式:支持双线性、近邻等多种插值方式,适配不同数据的平滑度需求(如音频频谱用近邻保留锐度)。

7. 预训练模型后处理器(FirstStagePostProcessor

将预训练模型(如VAE)的输出映射到目标网络(如扩散模型)的潜在空间,实现多模型协同工作(如“预训练VAE+扩散模型”的两阶段生成)。

7.1 结构与前向传播
class FirstStagePostProcessor(nn.Module):def __init__(self,ch_mult: list,  # 通道倍增因子in_channels,pretrained_model: nn.Module = None,  # 预训练模型(如VAE)reshape=False,  # 是否重塑特征为序列格式n_channels=None,  # 中间通道数dropout=0.0,pretrained_config=None,  # 预训练模型配置(用于实例化)):super().__init__()# 加载预训练模型(冻结参数,不参与训练)if pretrained_config is not None:self.instantiate_pretrained(pretrained_config)else:self.pretrained_model = pretrained_modelself.do_reshape = reshape# 通道投影:将预训练模型输出映射到中间通道if n_channels is None:n_channels = self.pretrained_model.encoder.chself.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)self.proj = nn.Conv2d(in_channels, n_channels, kernel_size=3, stride=1, padding=1)# 特征处理:残差块+下采样,适配目标网络的潜在空间blocks = []downs = []ch_in = n_channelsfor m in ch_mult:blocks.append(ResnetBlock(in_channels=ch_in, out_channels=m * n_channels, dropout=dropout))ch_in = m * n_channelsdowns.append(Downsample(ch_in, with_conv=False))self.model = nn.ModuleList(blocks)self.downsampler = nn.ModuleList(downs)def instantiate_pretrained(self, config):# 从配置实例化预训练模型并冻结self.pretrained_model = instantiate_from_config(config).eval()for param in self.pretrained_model.parameters():param.requires_grad = False@torch.no_grad()  # 预训练模型不计算梯度def encode_with_pretrained(self, x):# 用预训练模型编码(如VAE的编码器)c = self.pretrained_model.encode(x)if isinstance(c, DiagonalGaussianDistribution):  # 若输出为分布,取众数c = c.mode()return cdef forward(self, x):# 步骤1:预训练模型编码z_fs = self.encode_with_pretrained(x)# 步骤2:通道投影与增强z = self.proj_norm(z_fs)z = self.proj(z)z = nonlinearity(z)# 步骤3:特征处理(残差块+下采样)for submodel, downmodel in zip(self.model, self.downsampler):z = submodel(z, temb=None)z = downmodel(z)# 可选:重塑为序列格式(如用于Transformer)if self.do_reshape:z = rearrange(z, "b c h w -> b (h w) c")return z
7.2 核心作用
  • 模型桥接:解决预训练模型与下游模型的潜在空间不兼容问题(如将VAE的潜在变量转换为扩散模型可处理的格式)。
  • 特征适配:通过投影、残差块和下采样,调整预训练模型输出的通道和分辨率,使其符合下游模型的输入要求。
  • 冻结预训练:预训练模型参数固定(requires_grad=False),仅训练适配层,避免破坏预训练知识。

总结

辅助模块通过以下方式增强了框架的实用性:

  1. 灵活性:提供轻量级解码器(SimpleDecoder)、通用上采样器(Upsampler)等,适配不同复杂度需求。
  2. 兼容性:通过LatentRescalerMergedRescaleEncoder/Decoder解决不同网络间的尺度/通道不匹配问题。
  3. 扩展性FirstStagePostProcessor支持融合预训练模型,实现“预训练+微调”的两阶段生成,降低训练成本。

这些模块与核心网络(ModelEncoderDecoder)协同,形成了一套完整的生成式模型工具链,尤其适合音频等时空数据的生成、重建和尺度调整任务。

总结

这段代码构建了一套完整的生成式模型工具链,核心包括:

  1. 基础组件:残差块、注意力、上/下采样、时间步嵌入,适配时空数据(音频/图像)。
  2. 核心网络:扩散模型主网络(Model)、VAE风格的编码器(Encoder)和解码器(Decoder)。
  3. 辅助工具:尺度调整、预训练模型融合等模块,支持复杂生成流程。

整体设计聚焦于灵活性和音频处理特性(如时间-频率维度的差异化采样),可用于音频生成、降噪、重建等任务,是AudioLDM等音频生成模型的核心组件。

autoencoder.py

先给出完整代码:

import torch
from audioldm.latent_diffusion.ema import *
from audioldm.variational_autoencoder.modules import Encoder, Decoder
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution
from audioldm.hifigan.utilities import get_vocoder, vocoder_infer


class AutoencoderKL(nn.Module):
    def __init__(
        self,
        ddconfig=None,
        lossconfig=None,
        image_key="fbank",
        embed_dim=None,
        time_shuffle=1,
        subband=1,
        ckpt_path=None,
        reload_from_ckpt=None,
        ignore_keys=[],
        colorize_nlabels=None,
        monitor=None,
        base_learning_rate=1e-5,
    ):
        super().__init__()
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.subband = int(subband)
        if self.subband > 1:
            print("Use subband decomposition %s" % self.subband)
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.vocoder = get_vocoder(None, "cpu")
        self.embed_dim = embed_dim
        if monitor is not None:
            self.monitor = monitor
        self.time_shuffle = time_shuffle
        self.reload_from_ckpt = reload_from_ckpt
        self.reloaded = False
        self.mean, self.std = None, None

    def encode(self, x):
        # x = self.time_shuffle_operation(x)
        x = self.freq_split_subband(x)
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        dec = self.freq_merge_subband(dec)
        return dec

    def decode_to_waveform(self, dec):
        dec = dec.squeeze(1).permute(0, 2, 1)
        wav_reconstruction = vocoder_infer(dec, self.vocoder)
        return wav_reconstruction

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        if self.flag_first_run:
            print("Latent size: ", z.size())
            self.flag_first_run = False
        dec = self.decode(z)
        return dec, posterior

    def freq_split_subband(self, fbank):
        if self.subband == 1 or self.image_key != "stft":
            return fbank
        bs, ch, tstep, fbins = fbank.size()
        assert fbank.size(-1) % self.subband == 0
        assert ch == 1
        return (
            fbank.squeeze(1)
            .reshape(bs, tstep, self.subband, fbins // self.subband)
            .permute(0, 2, 1, 3)
        )

    def freq_merge_subband(self, subband_fbank):
        if self.subband == 1 or self.image_key != "stft":
            return subband_fbank
        assert subband_fbank.size(1) == self.subband  # Channel dimension
        bs, sub_ch, tstep, fbins = subband_fbank.size()
        return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)

This code implements an AutoencoderKL (a KL-regularized autoencoder) in PyTorch, used here for audio processing, as the audioldm dependencies suggest. The model follows the variational autoencoder (VAE) recipe: an encoder, a decoder, latent projections, and integrated speech-synthesis utilities. The details follow.

1. Classes and dependencies

  • Base class: AutoencoderKL inherits from torch.nn.Module, the base class of every PyTorch network module.
  • Dependencies:
    • Encoder and Decoder imported from audioldm: encode and decode audio features.
    • DiagonalGaussianDistribution: models the latent variable as a probability distribution (the heart of a VAE).
    • A vocoder: converts the spectrogram features produced by the model into an audible waveform.

2. The initializer (__init__)

The __init__ method of AutoencoderKL is the model's constructor. It sets up the core components, hyperparameters and configuration, fixing the model's basic structure and dependencies and laying the groundwork for the later encoding and decoding methods. A detailed walk-through follows.

Method signature and parameters

The method is defined as follows:

def __init__(
    self,
    ddconfig=None,
    lossconfig=None,
    image_key="fbank",
    embed_dim=None,
    time_shuffle=1,
    subband=1,
    ckpt_path=None,
    reload_from_ckpt=None,
    ignore_keys=[],
    colorize_nlabels=None,
    monitor=None,
    base_learning_rate=1e-5,
):

Parameter meanings:

  • ddconfig: configuration dict for the Encoder and Decoder (number of layers, channel counts and other core structural parameters).
  • lossconfig: configuration dict for the loss function (not used directly in this code; reserved for defining how the loss is computed).
  • image_key: key of the input feature ("fbank" for mel spectrograms, "stft" for short-time Fourier spectrograms), used to distinguish input types.
  • embed_dim: dimensionality of the latent variable (the latent space).
  • time_shuffle: time-axis shuffling parameter (no actual operation is enabled in this code; reserved for data augmentation).
  • subband: number of sub-bands the spectrogram is split into along the frequency axis (an efficiency optimization).
  • ckpt_path / reload_from_ckpt: checkpoint paths for loading pretrained weights.
  • ignore_keys: parameter names to skip when loading a checkpoint (avoids errors from mismatched parameters).
  • colorize_nlabels: reserved parameter (likely for color-mapping multi-class features; rarely relevant for audio).
  • monitor: metric to monitor during training (e.g. reconstruction loss), used for logging or early stopping.
  • base_learning_rate: base learning rate for the optimizer during training.

Core initialization logic

1. Call the parent constructor
super().__init__()

This is the standard first step for any PyTorch module: it initializes the internal state of nn.Module (parameter management, device handling, and so on) so that autograd, parameter saving and the rest of PyTorch's machinery work as expected.

2. Build the encoder and decoder
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
  • Encoder and Decoder come from audioldm.variational_autoencoder.modules and are the networks that encode and decode audio features.
  • **ddconfig unpacks the ddconfig dict as keyword arguments for the Encoder and Decoder constructors. For example, ddconfig may contain z_channels (latent feature channels) and layer-width settings that define the network structure (number of convolutional layers, per-layer output channels, and so on).
  • Role: the encoder compresses the input audio feature (e.g. a spectrogram) into a compact feature map; the decoder restores the latent variable back into a reconstructed audio feature.
3. Configure sub-band decomposition
self.subband = int(subband)
if self.subband > 1:
    print("Use subband decomposition %s" % self.subband)
  • subband controls whether the spectrogram is split into sub-bands (only effective when the input is an STFT spectrogram). For example, subband=4 splits the frequency axis into 4 sub-bands, each covering part of the frequency range.
  • Role: sub-band decomposition reduces the frequency dimension handled at once, cutting computation while accommodating the different behavior of audio in different frequency ranges.
4. Define the 1x1 convolutions around the latent space
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

These two 1x1 convolutions connect the encoder, the latent space and the decoder:

  • quant_conv: input channels are 2 * ddconfig["z_channels"] (twice the encoder's output width, holding the mean and variance parameters), output channels are 2 * embed_dim (mean plus variance of the latent distribution). It converts the encoder output into the parameters of the latent distribution.
  • post_quant_conv: input channels are embed_dim (the latent dimension), output channels are ddconfig["z_channels"] (what the decoder expects). It maps the latent variable back to the decoder's input width.
  • Why 1x1 convolutions: they leave the spatial dimensions untouched and only change the channel count, making them a cheap way to convert dimensions; the sketch below traces the shapes.
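A standalone shape trace of the two projections, using made-up sizes (z_channels=8, embed_dim=4; neither value is taken from a real AudioLDM config):

import torch

z_channels, embed_dim = 8, 4                                     # illustrative values
quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)   # 16 -> 8 channels
post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)      # 4 -> 8 channels

h = torch.randn(2, 2 * z_channels, 64, 16)      # encoder output: (batch, 16, T, F)
moments = quant_conv(h)                         # (2, 8, 64, 16): mean and log-variance, 4 channels each
mean, logvar = torch.chunk(moments, 2, dim=1)   # each (2, 4, 64, 16)
dec_in = post_quant_conv(mean)                  # (2, 8, 64, 16): back to the decoder's channel count
print(moments.shape, dec_in.shape)              # the spatial size (64, 16) is untouched throughout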
5. Initialize the vocoder
self.vocoder = get_vocoder(None, "cpu")
  • get_vocoder, imported from audioldm.hifigan.utilities, loads a pretrained vocoder (HiFi-GAN). The vocoder turns the spectrogram features produced by the model (STFT or mel) into an audible waveform.
  • The initial device is "cpu"; it can be moved to a GPU later if needed.
6. Store core attributes
self.embed_dim = embed_dim                  # latent dimensionality
if monitor is not None:
    self.monitor = monitor                  # metric monitored during training
self.time_shuffle = time_shuffle            # time-shuffle parameter (reserved)
self.reload_from_ckpt = reload_from_ckpt    # checkpoint path (loaded later)
self.reloaded = False                       # whether a checkpoint has been loaded yet
self.mean, self.std = None, None            # reserved for data mean/std (e.g. for normalization)

These attributes record the model's configuration, state and reserved functionality (normalization statistics, checkpoint-loading status, and so on).

Summary

The job of __init__ is to build the skeleton: by instantiating the encoder, decoder, the 1x1 projections and the vocoder, it fixes the basic structure of AutoencoderKL. It maps the configuration (ddconfig, embed_dim, ...) onto trainable components and state variables, on top of which encode, decode and the other methods operate.

The method is also tuned for audio processing:

  • sub-band decomposition is supported for efficiency;
  • a vocoder is built in so waveforms can be produced directly;
  • the latent variable is modeled as a probability distribution (implemented later in encode), which is the core idea of a variational autoencoder (VAE). A hypothetical construction example follows.
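A hypothetical construction sketch, assuming the audioldm package is installed. The keys and values in ddconfig are illustrative and must be replaced by the ones actually accepted by Encoder/Decoder in modules.py; note that constructing the model also loads the HiFi-GAN vocoder via get_vocoder.

from audioldm.variational_autoencoder.autoencoder import AutoencoderKL

# Illustrative configuration only; real AudioLDM configs ship with the repository.
ddconfig = {
    "double_z": True,        # encoder outputs 2 * z_channels (mean + log-variance)
    "z_channels": 8,
    "resolution": 256,
    "in_channels": 1,        # single-channel spectrogram
    "out_ch": 1,
    "ch": 128,
    "ch_mult": [1, 2, 4],
    "num_res_blocks": 2,
    "attn_resolutions": [],
    "dropout": 0.0,
}

# subband=1 keeps freq_split_subband a no-op, so the missing self.image_key
# attribute in the pasted __init__ is never touched.
model = AutoencoderKL(ddconfig=ddconfig, embed_dim=8, image_key="fbank", subband=1)
model.eval()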

3. Core methods

The core methods of AutoencoderKL implement the model's main pipeline: encoding the input audio feature into a latent distribution, sampling a latent variable, decoding it back into a reconstructed feature, converting that feature into an audible waveform, plus the auxiliary sub-band handling. Each method is walked through below.

3.1 Encoding (encode): map the input feature to a latent distribution

def encode(self, x):
    # x = self.time_shuffle_operation(x)   # reserved time-shuffle operation (not enabled)
    x = self.freq_split_subband(x)          # optional sub-band split of the input feature
    h = self.encoder(x)                     # encoder output
    moments = self.quant_conv(h)            # parameters of the latent distribution (mean and variance)
    posterior = DiagonalGaussianDistribution(moments)  # build the diagonal Gaussian
    return posterior                        # return the posterior distribution

Walk-through:
  • Input x: usually an audio spectrogram (mel fbank or STFT), shaped (batch_size, channels, time_steps, freq_bins).
  • Sub-band split (self.freq_split_subband(x)):
    if subband > 1 and the input is an STFT spectrogram (image_key="stft"), the spectrogram is split along the frequency axis into several sub-bands (see 3.5), shrinking the per-band frequency dimension and the compute cost. Otherwise x is returned unchanged.
  • Encoder (self.encoder(x)):
    the Encoder is a deep convolutional network that compresses the (possibly split) feature x into a feature map h whose shape is determined by ddconfig, e.g. (batch_size, 2*z_channels, ...).
  • Distribution parameters (self.quant_conv(h)):
    quant_conv is a 1x1 convolution that maps h (with 2*z_channels channels, since a VAE needs both a mean and a variance) into the distribution parameters moments, with 2*embed_dim channels (embed_dim for the mean and embed_dim for the variance).
  • Posterior construction (DiagonalGaussianDistribution(moments)):
    moments is split into a mean and a log-variance, defining a diagonal Gaussian (independent dimensions). This is the key VAE step: the encoder does not output the latent variable itself but its distribution, which is what makes the model "variational".
  • Return value: the posterior distribution, from which the latent variable z can later be sampled. A shape-tracing sketch follows.
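Reusing the hypothetical model from the construction sketch above, the shapes through encode can be traced as follows. The input shape is made up; the latent's spatial size depends on how much the encoder downsamples.

import torch

fbank = torch.randn(2, 1, 256, 64)          # (batch, 1, time_steps, mel_bins), illustrative
with torch.no_grad():
    posterior = model.encode(fbank)          # DiagonalGaussianDistribution
    z_train = posterior.sample()             # stochastic latent, used during training
    z_infer = posterior.mode()               # deterministic latent (the mean), used at inference
print(posterior.mean.shape, z_train.shape)   # both (2, embed_dim, T', F') after downsampling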

3.2 Decoding (decode): reconstruct features from the latent variable

def decode(self, z):
    z = self.post_quant_conv(z)          # adjust latent channels to match the decoder input
    dec = self.decoder(z)                # decoder produces the reconstructed feature
    dec = self.freq_merge_subband(dec)   # merge sub-bands (inverse of the split during encoding)
    return dec                           # return the reconstructed feature

Walk-through:
  • Input z: a latent variable sampled from the latent distribution, shaped (batch_size, embed_dim, time_steps, freq_bins_sub) (frequency dimension after sub-band splitting).
  • Channel adjustment (self.post_quant_conv(z)):
    post_quant_conv is a 1x1 convolution that maps the latent channels from embed_dim to the z_channels the decoder expects (matching the encoder's output width), keeping the dimensions compatible.
  • Decoder (self.decoder(z)):
    the Decoder mirrors the encoder; through convolutions (including upsampling) it turns the adjusted latent back into the reconstructed sub-band feature dec.
  • Sub-band merge (self.freq_merge_subband(dec)):
    if the input was split into sub-bands during encoding, the sub-bands are merged back into the original spectrogram layout here (see 3.5), so that the output matches the shape of the input x.
  • Return value: the reconstructed audio feature dec, with the same shape as the input x (e.g. (batch_size, channels, time_steps, freq_bins)). A decode-shape check follows.
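Continuing the encode sketch above, decoding the deterministic latent should restore the input shape (again under the hypothetical config):

with torch.no_grad():
    rec = model.decode(z_infer)          # expected (2, 1, 256, 64): same shape as the input fbank
print(rec.shape == fbank.shape)          # True if the config round-trips cleanly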

3.3 Waveform conversion (decode_to_waveform): make the reconstruction audible

def decode_to_waveform(self, dec):
    dec = dec.squeeze(1).permute(0, 2, 1)                   # rearrange dims to match the vocoder's input
    wav_reconstruction = vocoder_infer(dec, self.vocoder)   # vocoder generates the waveform
    return wav_reconstruction

Walk-through:
  • Input dec: the reconstructed feature (STFT or mel spectrogram), shaped (batch_size, 1, time_steps, freq_bins) (usually single-channel).
  • Dimension adjustment:
    • squeeze(1): drops the (trivial) channel dimension, giving (batch_size, time_steps, freq_bins).
    • permute(0, 2, 1): swaps time and frequency, giving (batch_size, freq_bins, time_steps), the (batch, freq, time) layout the vocoder expects.
  • Waveform generation (vocoder_infer):
    the vocoder (HiFi-GAN) is a model dedicated to turning spectrogram features into waveforms. vocoder_infer runs the pretrained self.vocoder on the rearranged dec and returns the reconstructed waveform wav_reconstruction (mono audio, one waveform per batch item).
  • Purpose: spectrogram features cannot be listened to directly; this method performs the final conversion into audible audio, the last step of any audio generation pipeline.

3.4 Forward pass (forward): the full encode-sample-decode pipeline

def forward(self, input, sample_posterior=True):
    posterior = self.encode(input)           # encode into a posterior distribution
    # sample from the posterior (training) or take the mean (inference)
    if sample_posterior:
        z = posterior.sample()
    else:
        z = posterior.mode()
    # print the latent size once, on the first run
    if self.flag_first_run:
        print("Latent size: ", z.size())
        self.flag_first_run = False
    dec = self.decode(z)                     # decode into the reconstructed feature
    return dec, posterior                    # return the reconstruction and the posterior

Walk-through:
  • Arguments:
    • input: the raw audio feature (the same as x in encode).
    • sample_posterior: whether to sample from the posterior (True) or take its mean (False).
  • Pipeline:
    1. Encode: call encode to obtain the posterior distribution.
    2. Sample / take the mean:
      • During training (sample_posterior=True): z is sampled from the posterior, which injects randomness and keeps the latent space smooth (the "variational" constraint of a VAE).
      • During inference (sample_posterior=False): the mean of the distribution is used (posterior.mode()), avoiding run-to-run variation in the output.
    3. Decode: call decode to reconstruct the feature dec from z.
    4. First-run printout: self.flag_first_run (which must be set to True beforehand; the pasted __init__ never defines it) prints the latent size once, which is handy when debugging the architecture.
  • Return values:
    • dec: the reconstructed feature (used for the reconstruction loss against input).
    • posterior: the posterior distribution (used for the KL loss that pulls the latent distribution toward the prior, usually a standard Gaussian). A sketched training step follows.
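A sketched training step under the usual VAE assumptions, reusing the hypothetical model and fbank from the earlier sketches: a reconstruction term plus a small KL penalty. The weight 1e-6 is purely illustrative (the real objective lives in the loss config, which this file never instantiates), and flag_first_run is set manually because the pasted __init__ does not define it.

import torch.nn.functional as F

model.train()
model.flag_first_run = False                       # attribute read in forward() but never set in __init__
dec, posterior = model(fbank, sample_posterior=True)
rec_loss = F.mse_loss(dec, fbank)                  # how far the reconstruction is from the input
kl_loss = posterior.kl().mean()                    # pulls q(z|x) toward the standard-normal prior
loss = rec_loss + 1e-6 * kl_loss                   # illustrative weighting only
loss.backward()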

3.5 Sub-band handling: splitting and merging the spectrogram

Sub-band handling is an optimization for STFT spectrograms: splitting the frequency axis reduces computation. It is active only when subband > 1 and image_key="stft".

3.5.1 Frequency split (freq_split_subband)
def freq_split_subband(self, fbank):
    # if the split does not apply, return the feature unchanged
    if self.subband == 1 or self.image_key != "stft":
        return fbank
    # input size: (batch_size, channels, time_steps, freq_bins)
    bs, ch, tstep, fbins = fbank.size()
    # checks: freq_bins must divide evenly into sub-bands, and the input must be single-channel
    assert fbank.size(-1) % self.subband == 0
    assert ch == 1
    # split the frequency axis into `subband` sub-bands
    return (
        fbank.squeeze(1)                                           # drop the channel dim: (bs, tstep, fbins)
        .reshape(bs, tstep, self.subband, fbins // self.subband)   # split the freq axis: (bs, tstep, subband, fbins/subband)
        .permute(0, 2, 1, 3)                                       # reorder: (bs, subband, tstep, fbins/subband); sub-bands become channels
    )
  • Role: splits the original frequency axis (fbins) evenly into subband pieces of size fbins/subband and uses the sub-band index as a new channel dimension, so the sub-bands can be processed in parallel.
3.5.2 Frequency merge (freq_merge_subband)
def freq_merge_subband(self, subband_fbank):
    # if the merge does not apply, return the sub-band feature unchanged
    if self.subband == 1 or self.image_key != "stft":
        return subband_fbank
    # check: the channel dimension must equal the number of sub-bands
    assert subband_fbank.size(1) == self.subband  # Channel dimension
    # sub-band feature size: (batch_size, subband, time_steps, fbins_sub)
    bs, sub_ch, tstep, fbins = subband_fbank.size()
    # stitch the sub-bands back into the original frequency axis
    return (
        subband_fbank.permute(0, 2, 1, 3)   # reorder: (bs, tstep, subband, fbins_sub)
        .reshape(bs, tstep, -1)             # merge: (bs, tstep, subband*fbins_sub) = (bs, tstep, fbins)
        .unsqueeze(1)                       # restore the channel dim: (bs, 1, tstep, fbins)
    )
  • Role: the exact inverse of freq_split_subband; it stitches the sub-bands back into the original spectrogram layout so that the decoder output matches the input shape. A round-trip check follows below.
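A standalone round-trip check (pure PyTorch, subband=4 chosen arbitrarily) using the same reshape/permute operations as the two methods, confirming that the merge exactly inverts the split:

import torch

subband = 4
fbank = torch.randn(2, 1, 100, 64)          # (bs, 1, tstep, fbins); 64 is divisible by 4

bs, ch, tstep, fbins = fbank.size()
split = (
    fbank.squeeze(1)
    .reshape(bs, tstep, subband, fbins // subband)
    .permute(0, 2, 1, 3)                    # (2, 4, 100, 16)
)
merged = (
    split.permute(0, 2, 1, 3)
    .reshape(bs, tstep, -1)
    .unsqueeze(1)                           # (2, 1, 100, 64)
)
assert torch.equal(fbank, merged)           # split and merge are exact inverses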

How the core methods fit together

Together these methods form the complete functional chain of AutoencoderKL:
input feature → encode (sub-band split → encoder → distribution) → sample a latent → decode (decoder → sub-band merge) → reconstructed feature → decode_to_waveform (audible waveform)

forward is the public entry point that strings the encode-sample-decode steps together; sub-band handling is an optional efficiency optimization; and decode_to_waveform is the audio-specific "last mile" that turns an abstract feature into something you can hear. The design keeps the probabilistic modeling of a VAE while adapting it to the characteristics of audio signals.
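Putting the chain together, still with the hypothetical model and fbank from the earlier sketches:

import torch

with torch.no_grad():
    dec, posterior = model(fbank, sample_posterior=False)   # encode -> take the mode -> decode
    waveform = model.decode_to_waveform(dec)                 # spectrogram -> audio via the HiFi-GAN vocoder
# `waveform` can then be written to disk with any audio I/O library.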

4. Core ideas of the model

This is a variant of the variational autoencoder (VAE), with three defining traits:

  1. Probabilistic latent modeling: the encoder outputs a distribution over the latent variable rather than a single value; sampling injects randomness and improves generalization.
  2. KL regularization: during training, the KL divergence between the posterior and the prior (usually a standard Gaussian) constrains the shape of the latent space.
  3. Audio-specific optimizations:
    • Sub-band handling: reduces the spectrogram dimensionality and respects the frequency structure of audio.
    • Built-in vocoder: directly supports feature-to-waveform conversion, which is convenient for audio generation tasks.

5. Typical uses

The model can serve audio generation, denoising, voice conversion and similar tasks. For example:

  • Training: the model learns to encode input audio (speech, music) into latents and reconstruct it, optimized by minimizing a reconstruction loss plus a KL loss.
  • Inference: sampling a latent from the prior and decoding it yields new audio (generation, sketched below); encoding noisy audio and decoding it can perform denoising.
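A sketch of the generation scenario in its simplest form: draw a latent from the standard-normal prior and decode it. The latent shape is made up and must match what the encoder actually produces for your config; in AudioLDM proper the latent is produced by the latent diffusion model rather than sampled from the raw prior.

import torch

z = torch.randn(1, 8, 64, 16)          # (batch, embed_dim, latent_time, latent_freq), hypothetical
with torch.no_grad():
    fake_fbank = model.decode(z)        # decoded spectrogram
    fake_wav = model.decode_to_waveform(fake_fbank)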

In short: this code implements a KL-regularized autoencoder for audio, combining the probabilistic modeling of a VAE with audio-specific sub-band handling and waveform conversion. It is one of the core components of audio generation models such as AudioLDM.
