HiFi-GAN模型代码分析
先给出完整的代码:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from utils import init_weights, get_paddingLRELU_SLOPE = 0.1class ResBlock1(torch.nn.Module):def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):super(ResBlock1, self).__init__()self.h = hself.convs1 = nn.ModuleList([weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],padding=get_padding(kernel_size, dilation[0]))),weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],padding=get_padding(kernel_size, dilation[1]))),weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],padding=get_padding(kernel_size, dilation[2])))])self.convs1.apply(init_weights)self.convs2 = nn.ModuleList([weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,padding=get_padding(kernel_size, 1))),weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,padding=get_padding(kernel_size, 1))),weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,padding=get_padding(kernel_size, 1)))])self.convs2.apply(init_weights)def forward(self, x):for c1, c2 in zip(self.convs1, self.convs2):xt = F.leaky_relu(x, LRELU_SLOPE)xt = c1(xt)xt = F.leaky_relu(xt, LRELU_SLOPE)xt = c2(xt)x = xt + xreturn xdef remove_weight_norm(self):for l in self.convs1:remove_weight_norm(l)for l in self.convs2:remove_weight_norm(l)class ResBlock2(torch.nn.Module):def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):super(ResBlock2, self).__init__()self.h = hself.convs = nn.ModuleList([weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],padding=get_padding(kernel_size, dilation[0]))),weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],padding=get_padding(kernel_size, dilation[1])))])self.convs.apply(init_weights)def forward(self, x):for c in self.convs:xt = F.leaky_relu(x, LRELU_SLOPE)xt = c(xt)x = xt + xreturn xdef remove_weight_norm(self):for l in self.convs:remove_weight_norm(l)class Generator(torch.nn.Module):def __init__(self, h):super(Generator, self).__init__()self.h = hself.num_kernels = len(h.resblock_kernel_sizes)self.num_upsamples = len(h.upsample_rates)self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))resblock = ResBlock1 if h.resblock == '1' else ResBlock2self.ups = nn.ModuleList()for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):self.ups.append(weight_norm(ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),k, u, padding=(k-u)//2)))self.resblocks = nn.ModuleList()for i in range(len(self.ups)):ch = h.upsample_initial_channel//(2**(i+1))for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):self.resblocks.append(resblock(h, ch, k, d))self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))self.ups.apply(init_weights)self.conv_post.apply(init_weights)def forward(self, x):x = self.conv_pre(x)for i in range(self.num_upsamples):x = F.leaky_relu(x, LRELU_SLOPE)x = self.ups[i](x)xs = Nonefor j in range(self.num_kernels):if xs is None:xs = self.resblocks[i*self.num_kernels+j](x)else:xs += self.resblocks[i*self.num_kernels+j](x)x = xs / self.num_kernelsx = F.leaky_relu(x)x = self.conv_post(x)x = torch.tanh(x)return xdef remove_weight_norm(self):print('Removing weight norm...')for l in self.ups:remove_weight_norm(l)for l in self.resblocks:l.remove_weight_norm()remove_weight_norm(self.conv_pre)remove_weight_norm(self.conv_post)class DiscriminatorP(torch.nn.Module):def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):super(DiscriminatorP, self).__init__()self.period = periodnorm_f = weight_norm if use_spectral_norm == False else spectral_normself.convs = nn.ModuleList([norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),])self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))def forward(self, x):fmap = []# 1d to 2db, c, t = x.shapeif t % self.period != 0: # pad firstn_pad = self.period - (t % self.period)x = F.pad(x, (0, n_pad), "reflect")t = t + n_padx = x.view(b, c, t // self.period, self.period)for l in self.convs:x = l(x)x = F.leaky_relu(x, LRELU_SLOPE)fmap.append(x)x = self.conv_post(x)fmap.append(x)x = torch.flatten(x, 1, -1)return x, fmapclass MultiPeriodDiscriminator(torch.nn.Module):def __init__(self):super(MultiPeriodDiscriminator, self).__init__()self.discriminators = nn.ModuleList([DiscriminatorP(2),DiscriminatorP(3),DiscriminatorP(5),DiscriminatorP(7),DiscriminatorP(11),])def forward(self, y, y_hat):y_d_rs = []y_d_gs = []fmap_rs = []fmap_gs = []for i, d in enumerate(self.discriminators):y_d_r, fmap_r = d(y)y_d_g, fmap_g = d(y_hat)y_d_rs.append(y_d_r)fmap_rs.append(fmap_r)y_d_gs.append(y_d_g)fmap_gs.append(fmap_g)return y_d_rs, y_d_gs, fmap_rs, fmap_gsclass DiscriminatorS(torch.nn.Module):def __init__(self, use_spectral_norm=False):super(DiscriminatorS, self).__init__()norm_f = weight_norm if use_spectral_norm == False else spectral_normself.convs = nn.ModuleList([norm_f(Conv1d(1, 128, 15, 1, padding=7)),norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),])self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))def forward(self, x):fmap = []for l in self.convs:x = l(x)x = F.leaky_relu(x, LRELU_SLOPE)fmap.append(x)x = self.conv_post(x)fmap.append(x)x = torch.flatten(x, 1, -1)return x, fmapclass MultiScaleDiscriminator(torch.nn.Module):def __init__(self):super(MultiScaleDiscriminator, self).__init__()self.discriminators = nn.ModuleList([DiscriminatorS(use_spectral_norm=True),DiscriminatorS(),DiscriminatorS(),])self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2),AvgPool1d(4, 2, padding=2)])def forward(self, y, y_hat):y_d_rs = []y_d_gs = []fmap_rs = []fmap_gs = []for i, d in enumerate(self.discriminators):if i != 0:y = self.meanpools[i-1](y)y_hat = self.meanpools[i-1](y_hat)y_d_r, fmap_r = d(y)y_d_g, fmap_g = d(y_hat)y_d_rs.append(y_d_r)fmap_rs.append(fmap_r)y_d_gs.append(y_d_g)fmap_gs.append(fmap_g)return y_d_rs, y_d_gs, fmap_rs, fmap_gsdef feature_loss(fmap_r, fmap_g):loss = 0for dr, dg in zip(fmap_r, fmap_g):for rl, gl in zip(dr, dg):loss += torch.mean(torch.abs(rl - gl))return loss*2def discriminator_loss(disc_real_outputs, disc_generated_outputs):loss = 0r_losses = []g_losses = []for dr, dg in zip(disc_real_outputs, disc_generated_outputs):r_loss = torch.mean((1-dr)**2)g_loss = torch.mean(dg**2)loss += (r_loss + g_loss)r_losses.append(r_loss.item())g_losses.append(g_loss.item())return loss, r_losses, g_lossesdef generator_loss(disc_outputs):loss = 0gen_losses = []for dg in disc_outputs:l = torch.mean((1-dg)**2)gen_losses.append(l)loss += lreturn loss, gen_losses
这段代码实现了一个基于生成对抗网络(GAN)的音频生成模型,具体结构类似于GHiFi-GAN(一种高性能声码器),主要用于从梅尔频谱(Mel-spectrogram)生成原始音频波形。下面分模块详细解释:
核心组件概览
代码包含生成器(Generator)、两种鉴别器(多周期鉴别器、多尺度鉴别器)、残差块(ResBlock)及对应的损失函数,形成完整的GAN训练框架。
1. 残差块(ResBlock)
在这段音频生成模型代码中,残差块(ResBlock)是生成器提取特征的核心组件,通过残差连接(Residual Connection)缓解深层网络的梯度消失问题,同时利用膨胀卷积(Dilated Convolution)扩大感受野,更有效地捕捉音频的时序依赖关系。代码中实现了两种残差块:ResBlock1
和ResBlock2
,下面分别详细解析。
1. ResBlock1 详解
ResBlock1
是更复杂的残差块结构,包含两组卷积层,通过不同膨胀率的卷积提取多尺度特征,再通过残差连接融合输入与输出。
1.1 初始化方法(init)
class ResBlock1(torch.nn.Module):def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):super(ResBlock1, self).__init__()self.h = h # 模型超参数配置(未直接使用,预留扩展)# 第一组卷积:带不同膨胀率的1D卷积self.convs1 = nn.ModuleList([weight_norm(Conv1d(channels, # 输入通道数channels, # 输出通道数(与输入相同,保证残差连接维度匹配)kernel_size, # 卷积核大小(如3)stride=1, # 步长1(不改变时间维度)dilation=dilation[0], # 膨胀率(控制感受野)padding=get_padding(kernel_size, dilation[0]) # 自动计算填充,保证输出长度不变)),# 重复定义另外两个卷积层,使用不同膨胀率dilation[1]和dilation[2]weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],padding=get_padding(kernel_size, dilation[1]))),weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],padding=get_padding(kernel_size, dilation[2])))])self.convs1.apply(init_weights) # 初始化卷积层权重# 第二组卷积:固定膨胀率=1的1D卷积(普通卷积)self.convs2 = nn.ModuleList([weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,padding=get_padding(kernel_size, 1))),weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,padding=get_padding(kernel_size, 1))),weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,padding=get_padding(kernel_size, 1)))])self.convs2.apply(init_weights) # 初始化卷积层权重
关键细节:
- 膨胀卷积(Dilated Convolution):
convs1
的三个卷积层分别使用dilation=(1,3,5)
,膨胀率越大,感受野(Receptive Field)越大(无需增加卷积核大小即可捕捉更长时序的依赖关系),适合音频这种长时序数据。 - 权重归一化(weight_norm):对卷积层应用权重归一化,稳定训练过程(减少梯度波动,加速收敛)。
- 填充计算(get_padding):通过
get_padding(kernel_size, dilation)
自动计算填充大小,确保卷积后时间维度不变(输入输出长度相同,满足残差连接的维度匹配)。 - 两组卷积设计:
convs1
(膨胀卷积)负责扩大感受野提取多尺度特征,convs2
(普通卷积)负责特征细化,增强特征表达能力。
1.2 前向传播(forward)
def forward(self, x):for c1, c2 in zip(self.convs1, self.convs2):xt = F.leaky_relu(x, LRELU_SLOPE) # 激活函数(LeakyReLU,斜率0.1)xt = c1(xt) # 第一组膨胀卷积xt = F.leaky_relu(xt, LRELU_SLOPE) # 再次激活xt = c2(xt) # 第二组普通卷积x = xt + x # 残差连接:当前输出 + 原始输入return x
数据流动过程:
- 输入
x
先通过LeakyReLU激活(引入非线性)。 - 经过
convs1
的膨胀卷积提取多尺度特征。 - 再次激活后,经过
convs2
的普通卷积细化特征。 - 将卷积结果
xt
与原始输入x
相加(残差连接),得到当前残差块的输出。 - 重复上述过程(共3次,与
convs1
/convs2
的长度一致),逐步强化特征。
残差连接的作用:直接将输入x
加到输出xt
中,避免深层网络的梯度消失(梯度可通过x
直接反向传播),同时保留原始特征,增强模型对细微特征的捕捉能力。
1.3 移除权重归一化(remove_weight_norm)
def remove_weight_norm(self):for l in self.convs1:remove_weight_norm(l)for l in self.convs2:remove_weight_norm(l)
在模型推理(生成音频)阶段,移除权重归一化可减少计算量,提高推理速度(训练时需要权重归一化稳定训练,推理时无需)。
2. ResBlock2 详解
ResBlock2
是简化版的残差块,仅包含一组卷积层(膨胀卷积),计算量更小,适合对效率要求较高的场景。
2.1 初始化方法(init)
class ResBlock2(torch.nn.Module):def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):super(ResBlock2, self).__init__()self.h = h # 超参数配置(预留扩展)self.convs = nn.ModuleList([weight_norm(Conv1d(channels,channels,kernel_size,stride=1,dilation=dilation[0], # 第一个膨胀率padding=get_padding(kernel_size, dilation[0]))),weight_norm(Conv1d(channels,channels,kernel_size,stride=1,dilation=dilation[1], # 第二个膨胀率padding=get_padding(kernel_size, dilation[1])))])self.convs.apply(init_weights) # 初始化权重
与ResBlock1的差异:
- 仅包含一组卷积层
convs
(长度为2,与dilation=(1,3)
匹配),无ResBlock1
中的第二组普通卷积,结构更简单。 - 膨胀率通常较小(如
(1,3)
),感受野扩展更温和,计算量更低。
2.2 前向传播(forward)
def forward(self, x):for c in self.convs:xt = F.leaky_relu(x, LRELU_SLOPE) # 激活xt = c(xt) # 膨胀卷积x = xt + x # 残差连接return x
数据流动过程:
- 输入
x
通过LeakyReLU激活。 - 经过
convs
的膨胀卷积提取特征。 - 卷积结果
xt
与原始输入x
相加(残差连接)。 - 重复上述过程(共2次,与
convs
的长度一致)。
简化的意义:减少卷积层数量,降低计算复杂度,同时保留残差连接的核心优势(缓解梯度消失),适合资源有限的场景或作为轻量化模型的组件。
2.3 移除权重归一化(remove_weight_norm)
def remove_weight_norm(self):for l in self.convs:remove_weight_norm(l)
与ResBlock1
同理,推理阶段移除权重归一化以提高效率。
3. 两种残差块的对比与应用
特性 | ResBlock1 | ResBlock2 |
---|---|---|
卷积层组数 | 2组(膨胀卷积+普通卷积) | 1组(仅膨胀卷积) |
卷积层数量 | 3+3=6层 | 2层 |
感受野 | 更大(多组膨胀率+普通卷积) | 较小(仅两组膨胀率) |
计算量 | 较高 | 较低 |
适用场景 | 追求高特征表达能力(如高质量生成) | 追求效率(如快速推理) |
在生成器中,通过参数h.resblock
选择使用ResBlock1
或ResBlock2
,两者均作为特征提取的基本单元,在每个上采样步骤后堆叠,逐步将梅尔频谱的特征转换为音频波形的特征。
总结
残差块是该音频生成模型的核心组件,通过:
- 残差连接:解决深层网络梯度消失问题,保留原始特征。
- 膨胀卷积:在不增加卷积核大小的情况下扩大感受野,捕捉音频的长时序依赖。
- 权重归一化:稳定训练过程,加速收敛。
ResBlock1
和ResBlock2
分别从“特征表达能力”和“计算效率”角度设计,可根据实际需求选择,共同支撑生成器从梅尔频谱到音频波形的高质量转换。
2. 生成器(Generator)
生成器(Generator)是该音频生成模型的核心组件,负责将输入的80维梅尔频谱(Mel-spectrogram)转换为1维原始音频波形。其设计核心是通过多步上采样逐步扩大时间维度(从梅尔频谱的短时长相音频的长时长),并通过残差块提取和强化特征,最终输出高质量音频。以下是详细解析:
1. 生成器的初始化(__init__方法)
生成器的初始化过程定义了从输入映射、上采样、特征提取到输出映射的完整组件链,核心参数依赖于配置h
(包含上采样率、卷积核大小等超参数)。
class Generator(torch.nn.Module):def __init__(self, h):super(Generator, self).__init__()self.h = h # 模型超参数配置(如采样率、卷积核大小等)self.num_kernels = len(h.resblock_kernel_sizes) # 每个上采样步骤对应的残差块数量self.num_upsamples = len(h.upsample_rates) # 上采样总步数# 输入映射:将80维梅尔频谱转换为高维特征self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))# 上采样层:通过转置卷积实现时间维度扩展self.ups = nn.ModuleList()for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):self.ups.append(weight_norm(ConvTranspose1d(# 输入通道数:初始通道数 // 2^i(每次上采样后通道数减半)h.upsample_initial_channel // (2 **i),# 输出通道数:初始通道数 // 2^(i+1)h.upsample_initial_channel // (2** (i + 1)),kernel_size=k, # 上采样卷积核大小stride=u, # 上采样倍数(与upsample_rates对应)padding=(k - u) // 2 # 计算填充,确保上采样后时间维度正确扩展)))# 残差块组:每个上采样步骤后接多组残差块,用于特征提取self.resblocks = nn.ModuleList()for i in range(len(self.ups)):# 当前上采样步骤后的特征通道数(随上采样逐步减半)ch = h.upsample_initial_channel // (2 **(i + 1))# 为每个上采样步骤添加num_kernels个残差块for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):# 根据配置选择ResBlock1或ResBlock2resblock = ResBlock1 if h.resblock == '1' else ResBlock2self.resblocks.append(resblock(h, ch, k, d))# 输出映射:将高维特征转换为1维音频波形self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))# 初始化权重self.ups.apply(init_weights)self.conv_post.apply(init_weights)
关键组件解析
1.输入映射(conv_pre)- 作用:将80维梅尔频谱(输入特征)映射到高维特征空间(通道数为h.upsample_initial_channel
,如512),为后续特征提取做准备。
- 实现:1D卷积(
Conv1d
),卷积核大小7,padding=3,确保时间维度不变(输入输出长度相同)。 - 权重归一化:应用
weight_norm
稳定训练。
2.上采样层(self.ups)- 作用:通过转置卷积(ConvTranspose1d) 逐步扩大时间维度(梅尔频谱的时间步长较短,音频的时间步长较长,需通过上采样匹配)。
- 核心参数:
upsample_rates
:上采样倍数列表(如[8,8,2,2]
),总上采样倍数为各值乘积(8×8×2×2=256,即梅尔频谱长度×256=音频长度)。upsample_kernel_sizes
:上采样卷积核大小(需与上采样率匹配,如[16,16,4,4]
),确保通过padding计算((k-u)//2
)使时间维度按u
倍扩展。
- 通道数变化:每次上采样后通道数减半(如512→256→128→64→32),平衡计算量与特征表达能力。
3.残差块组(self.resblocks)- 作用:对每个上采样步骤后的特征进行细化提取,捕捉音频的局部与全局时序依赖。
- 组织方式:每个上采样步骤后接
num_kernels
个残差块(如3个),残差块类型由h.resblock
决定(ResBlock1
或ResBlock2
)。 - 通道一致性:残差块的输入/输出通道数与当前上采样步骤后的通道数
ch
一致,确保残差连接有效。
4.输出映射(conv_post)- 作用:将最终的高维特征(如32通道)转换为1维音频波形。
- 实现:1D卷积(
Conv1d
),卷积核大小7,padding=3,输出通道数=1。
2. 前向传播(forward方法)
前向传播定义了数据从梅尔频谱输入到音频输出的完整流动过程,核心是“上采样→残差特征提取→特征融合”的迭代过程。
def forward(self, x):# 步骤1:输入映射(梅尔频谱→高维特征)x = self.conv_pre(x) # 形状:(batch, 80, T_mel) → (batch, initial_ch, T_mel)# 步骤2:多步上采样+残差特征提取for i in range(self.num_upsamples):x = F.leaky_relu(x, LRELU_SLOPE) # 激活函数(引入非线性)x = self.ups[i](x) # 上采样:时间维度扩大u倍,通道数减半# 多个残差块并行处理,结果平均融合xs = Nonefor j in range(self.num_kernels):# 取出当前上采样步骤对应的第j个残差块resblock = self.resblocks[i * self.num_kernels + j]if xs is None:xs = resblock(x) # 首次处理:直接赋值else:xs += resblock(x) # 后续处理:累加特征x = xs / self.num_kernels # 特征融合(平均):降低过拟合风险# 步骤3:输出映射(高维特征→音频波形)x = F.leaky_relu(x) # 最终激活x = self.conv_post(x) # 形状:(batch, ch, T_audio) → (batch, 1, T_audio)x = torch.tanh(x) # 输出范围归一化到[-1, 1](音频信号常见范围)return x
数据流动细节
-
输入阶段:输入
x
为梅尔频谱,形状为(batch_size, 80, T_mel)
(80是梅尔频谱维度,T_mel
是时间步长)。经过conv_pre
后,形状变为(batch_size, initial_ch, T_mel)
(如(32, 512, 100)
)。 -
上采样阶段:
- 每次上采样通过
self.ups[i]
将时间维度扩大h.upsample_rates[i]
倍(如8倍),通道数减半(如512→256)。 - 上采样后,通过
num_kernels
个残差块并行处理(如3个),每个残差块输出相同形状的特征,累加后平均(xs / num_kernels
),实现多尺度特征融合。
- 每次上采样通过
-
输出阶段:最终特征经过
conv_post
转换为1通道,再通过tanh
激活,输出形状为(batch_size, 1, T_audio)
的音频波形(T_audio = T_mel × 总上采样倍数
)。
3. 移除权重归一化(remove_weight_norm方法)
训练时为稳定收敛使用了权重归一化,但推理(生成音频)时无需,因此提供该方法移除归一化以提高效率:
def remove_weight_norm(self):print('Removing weight norm...')for l in self.ups:remove_weight_norm(l) # 移除上采样层的权重归一化for l in self.resblocks:l.remove_weight_norm() # 移除残差块的权重归一化remove_weight_norm(self.conv_pre) # 移除输入映射的权重归一化remove_weight_norm(self.conv_post) # 移除输出映射的权重归一化
4. 生成器设计亮点
1.渐进式上采样:通过多步小倍数上采样(而非一步大倍数上采样),避免直接扩展导致的特征模糊,逐步恢复高频细节。
2.残差特征融合:每个上采样步骤后用多个残差块并行处理并平均结果,融合多尺度特征,增强生成音频的丰富性。
3.膨胀卷积应用:残差块中使用膨胀卷积,在不增加计算量的情况下扩大感受野,有效捕捉音频的长时序依赖(如语音中的上下文信息)。
4.权重归一化:稳定训练过程,减少梯度波动,使深层网络更容易收敛。
总结
生成器通过“输入映射→多步上采样+残差特征提取→输出映射”的流程,将梅尔频谱转换为音频波形。核心设计围绕“渐进式扩展时间维度”和“多尺度特征融合”,结合残差连接和膨胀卷积,在保证生成质量的同时,平衡了计算效率与训练稳定性。这一结构使其特别适合作为语音合成系统中的声码器(Vocoder),生成高保真、自然的音频。
3. 鉴别器(Discriminator)
在该音频生成模型中,鉴别器(Discriminator)的核心作用是区分“真实音频”和“生成器输出的伪造音频”,通过与生成器的对抗训练,推动生成器生成更逼真的音频。代码中设计了两种互补的鉴别器结构:多周期鉴别器(MultiPeriodDiscriminator) 和多尺度鉴别器(MultiScaleDiscriminator),从不同角度捕捉音频特征,增强判别能力。以下是详细解析:
1. 多周期鉴别器(MultiPeriodDiscriminator)
多周期鉴别器通过周期子序列分割捕捉音频的周期性模式(如语音的基频、音乐的节奏等),从多个周期尺度鉴别音频真实性。它包含多个子鉴别器DiscriminatorP
,每个子鉴别器专注于特定周期的特征。
1.1 子鉴别器(DiscriminatorP)
DiscriminatorP
是多周期鉴别器的基本单元,针对特定周期period
处理音频,将1D音频转换为2D周期特征后进行判别。
初始化(init)
class DiscriminatorP(torch.nn.Module):def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):super(DiscriminatorP, self).__init__()self.period = period # 周期(如2、3、5等,用于分割音频为子序列)# 选择权重归一化或谱归一化(谱归一化更适合稳定GAN训练)norm_f = weight_norm if use_spectral_norm == False else spectral_norm# 2D卷积层序列:逐步提取周期特征,通道数从1→32→128→512→1024self.convs = nn.ModuleList([norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), # 步长1,不改变空间维度])# 输出层:将特征映射为判别分数(真实/伪造)self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
关键设计:
- 周期分割:针对特定周期
period
(如2),将音频分割为长度为period
的子序列(如音频[t0,t1,t2,t3]
按周期2分割为[[t0,t1], [t2,t3]]
),转换为2D特征(形状:(batch, 1, 子序列数, period)
),便于捕捉周期性模式。 - 2D卷积:使用
Conv2d
处理周期特征,卷积核形状为(kernel_size, 1)
,仅在子序列数维度(时间方向)滑动,保留周期内的时序关系。 - 归一化选择:支持
weight_norm
(权重归一化)和spectral_norm
(谱归一化),后者通过限制权重矩阵的谱范数,更能稳定GAN训练。
前向传播(forward)
def forward(self, x):fmap = [] # 存储各层特征图(用于特征匹配损失)# 步骤1:1D音频→2D周期特征(按周期分割)b, c, t = x.shape # x形状:(batch, 1, 时间步长)if t % self.period != 0: # 补零使时间步长为周期的整数倍n_pad = self.period - (t % self.period)x = F.pad(x, (0, n_pad), "reflect") # 反射补零(减少边界效应)t = t + n_pad# 重塑为2D:(batch, 1, 子序列数, 周期长度)x = x.view(b, c, t // self.period, self.period)# 步骤2:通过卷积层提取特征并记录特征图for l in self.convs:x = l(x) # 卷积操作x = F.leaky_relu(x, LRELU_SLOPE) # 激活函数(LeakyReLU,斜率0.1)fmap.append(x) # 保存当前层特征图# 步骤3:输出判别分数x = self.conv_post(x) # 映射为判别分数(形状:(batch, 1, ...))fmap.append(x) # 保存输出层特征图x = torch.flatten(x, 1, -1) # 展平为(batch, 分数)return x, fmap # 返回判别分数和特征图列表
核心流程:
- 周期转换:将1D音频转换为2D周期特征,突出周期性模式(如语音的基频周期)。
- 特征提取:通过多组2D卷积逐步提升通道数(1→1024),压缩时间维度(子序列数减少),捕捉高层周期特征。
- 判别输出:最终通过
conv_post
输出判别分数(值越大越可能是真实音频),同时记录各层特征图用于后续损失计算。
1.2 多周期鉴别器整体(MultiPeriodDiscriminator)
多周期鉴别器由多个不同周期的DiscriminatorP
组成,从多个周期尺度(2、3、5、7、11)联合判别,覆盖音频中不同频率的周期性模式。
class MultiPeriodDiscriminator(torch.nn.Module):def __init__(self):super(MultiPeriodDiscriminator, self).__init__()# 包含5个不同周期的子鉴别器(周期2、3、5、7、11)self.discriminators = nn.ModuleList([DiscriminatorP(2),DiscriminatorP(3),DiscriminatorP(5),DiscriminatorP(7),DiscriminatorP(11),])def forward(self, y, y_hat):# y:真实音频;y_hat:生成器输出的伪造音频y_d_rs = [] # 真实音频的判别分数列表y_d_gs = [] # 伪造音频的判别分数列表fmap_rs = [] # 真实音频的特征图列表fmap_gs = [] # 伪造音频的特征图列表# 每个子鉴别器分别处理真实和伪造音频for i, d in enumerate(self.discriminators):y_d_r, fmap_r = d(y) # 真实音频的判别结果y_d_g, fmap_g = d(y_hat) # 伪造音频的判别结果y_d_rs.append(y_d_r)fmap_rs.append(fmap_r)y_d_gs.append(y_d_g)fmap_gs.append(fmap_g)return y_d_rs, y_d_gs, fmap_rs, fmap_gs
设计目的:不同音频(如语音、音乐)的周期性模式不同(如语音基频约50-500Hz,对应周期20-2ms),使用多个周期(2、3、5、7、11)可覆盖更广泛的周期范围,避免单一周期的判别偏差,提升整体判别能力。
2. 多尺度鉴别器(MultiScaleDiscriminator)
多尺度鉴别器通过下采样生成不同时间尺度的音频,从多个分辨率(原始、1/2、1/4)捕捉音频的局部细节和全局结构,与多周期鉴别器形成互补。它包含多个子鉴别器DiscriminatorS
,每个子鉴别器处理特定尺度的音频。
2.1 子鉴别器(DiscriminatorS)
DiscriminatorS
是多尺度鉴别器的基本单元,使用1D卷积直接处理音频,通过步长和分组卷积提取不同尺度的特征。
初始化(init)
class DiscriminatorS(torch.nn.Module):def __init__(self, use_spectral_norm=False):super(DiscriminatorS, self).__init__()norm_f = weight_norm if use_spectral_norm == False else spectral_norm# 1D卷积层序列:逐步提取时序特征,通道数从1→128→256→512→1024self.convs = nn.ModuleList([norm_f(Conv1d(1, 128, 15, 1, padding=7)), # 步长1,不改变时间维度norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), # 步长2下采样,分组卷积norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), # 步长2下采样norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), # 步长4下采样norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), # 步长4下采样norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), # 步长1norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), # 步长1,小卷积核细化特征])# 输出层:映射为判别分数self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
关键设计:
- 1D卷积直接处理:无需周期分割,直接对1D音频进行卷积,更侧重捕捉时序连续性特征(如音频的瞬态变化)。
- 下采样与分组卷积:通过步长(2、4)实现时间维度下采样(降低分辨率),同时使用分组卷积(
groups
)减少参数计算量,增强特征多样性。 - 多尺度特征:卷积核大小从15→41→5,结合不同步长,捕捉从局部到全局的时序特征。
前向传播(forward)
def forward(self, x):fmap = [] # 存储各层特征图# 步骤1:通过卷积层提取特征并记录特征图for l in self.convs:x = l(x) # 卷积操作x = F.leaky_relu(x, LRELU_SLOPE) # 激活函数fmap.append(x) # 保存当前层特征图# 步骤2:输出判别分数x = self.conv_post(x) # 映射为判别分数fmap.append(x) # 保存输出层特征图x = torch.flatten(x, 1, -1) # 展平为(batch, 分数)return x, fmap # 返回判别分数和特征图列表
核心流程:直接对1D音频进行多步卷积和下采样,逐步压缩时间维度、提升通道数,捕捉不同尺度的时序特征,最终输出判别分数和特征图。
2.2 多尺度鉴别器整体(MultiScaleDiscriminator)
多尺度鉴别器由3个DiscriminatorS
组成,通过平均池化生成不同尺度的音频(原始、1/2、1/4),从粗到细覆盖音频的全局和局部特征。
class MultiScaleDiscriminator(torch.nn.Module):def __init__(self):super(MultiScaleDiscriminator, self).__init__()# 3个子鉴别器(第1个使用谱归一化,增强稳定性)self.discriminators = nn.ModuleList([DiscriminatorS(use_spectral_norm=True),DiscriminatorS(),DiscriminatorS(),])# 平均池化层:用于生成低尺度音频(1/2、1/4)self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), # 下采样1/2(核4,步长2)AvgPool1d(4, 2, padding=2) # 再下采样1/2(总1/4)])def forward(self, y, y_hat):y_d_rs = [] # 真实音频的判别分数列表y_d_gs = [] # 伪造音频的判别分数列表fmap_rs = [] # 真实音频的特征图列表fmap_gs = [] # 伪造音频的特征图列表# 每个子鉴别器处理不同尺度的音频for i, d in enumerate(self.discriminators):if i != 0: # 第1个处理原始尺度,第2/3个处理下采样后的尺度y = self.meanpools[i-1](y) # 真实音频下采样y_hat = self.meanpools[i-1](y_hat) # 伪造音频下采样# 子鉴别器处理当前尺度的音频y_d_r, fmap_r = d(y)y_d_g, fmap_g = d(y_hat)y_d_rs.append(y_d_r)fmap_rs.append(fmap_r)y_d_gs.append(y_d_g)fmap_gs.append(fmap_g)return y_d_rs, y_d_gs, fmap_rs, fmap_gs
设计目的:音频的高频细节(如瞬态音)和低频结构(如整体节奏)需要不同分辨率的特征捕捉。通过平均池化生成1/2、1/4尺度的音频,使子鉴别器专注于不同频率范围的特征,提升对细微差异的判别能力。
3. 鉴别器的协同作用与训练目标
两种鉴别器(多周期+多尺度)从不同角度判别音频真实性:
- 多周期鉴别器:聚焦周期性模式(如语音基频、音乐节拍),擅长捕捉“韵律一致性”。
- 多尺度鉴别器:聚焦时序连续性(如音频的平滑过渡、瞬态变化),擅长捕捉“细节真实性”。
训练时,鉴别器的目标是最大化对真实音频的判别分数(接近1),最小化对伪造音频的判别分数(接近0);而生成器则通过对抗训练,尝试欺骗鉴别器(使伪造音频的判别分数接近1)。同时,鉴别器输出的特征图用于计算“特征匹配损失”,进一步约束生成音频的特征分布与真实音频一致。
总结
鉴别器通过“多周期+多尺度”的组合设计,从周期性和时序连续性两个维度全面判别音频真实性:
- 每个子鉴别器通过卷积层提取特征,输出判别分数和特征图。
- 多周期设计覆盖不同周期模式,多尺度设计覆盖不同分辨率特征。
- 与生成器的对抗训练推动生成器生成更逼真、细节更丰富的音频。
这种结构是高性能音频生成模型(如GHiFi-GAN)的核心设计,使其能够生成接近真实的高保真音频。
4. 损失函数
在该音频生成GAN模型中,损失函数是连接生成器(Generator)和鉴别器(Discriminator)的核心,通过对抗训练推动双方迭代优化。代码中定义了三类关键损失函数:** 特征匹配损失(feature_loss)、 鉴别器损失(discriminator_loss)和 生成器损失(generator_loss)**,它们协同作用以确保生成音频的真实性和质量。以下是详细解析:
1. 特征匹配损失(feature_loss)
特征匹配损失用于约束生成音频的特征分布与真实音频一致,补充对抗损失的不足(仅靠对抗损失可能导致生成样本“骗过”鉴别器但特征不真实)。它通过计算真实音频和生成音频在鉴别器各层特征图的差异,引导生成器学习更细腻的特征。
def feature_loss(fmap_r, fmap_g):loss = 0# 遍历所有鉴别器的特征图列表(多周期+多尺度鉴别器的特征图)for dr, dg in zip(fmap_r, fmap_g):# 遍历单个鉴别器内的各层特征图for rl, gl in zip(dr, dg):# 计算当前层特征图的L1损失(平均绝对误差)loss += torch.mean(torch.abs(rl - gl))# 缩放损失值(经验系数,增强该损失的权重)return loss * 2
关键细节
- 输入:
fmap_r
是真实音频经过鉴别器后输出的特征图列表,fmap_g
是生成音频经过相同鉴别器后输出的特征图列表(包含多周期和多尺度鉴别器的所有层特征)。 - 计算方式:对每一层特征图(
rl
为真实特征,gl
为生成特征)计算L1损失(torch.abs(rl - gl)
的均值),累加所有层的损失后乘以2(缩放系数,平衡与其他损失的权重)。 - 作用:强制生成音频在鉴别器的中间特征层面与真实音频相似,避免生成器仅优化“骗过鉴别器”的表层特征,而忽略音频的细节结构(如频谱分布、时序连贯性)。
2. 鉴别器损失(discriminator_loss)
鉴别器的目标是最大化对“真实音频”的判别分数(接近1),同时最小化对“生成音频”的判别分数(接近0)。该损失函数量化了鉴别器的分类误差,指导其优化以更好地区分真假音频。
def discriminator_loss(disc_real_outputs, disc_generated_outputs):loss = 0r_losses = [] # 记录每个鉴别器对真实音频的损失g_losses = [] # 记录每个鉴别器对生成音频的损失# 遍历所有鉴别器的输出(多周期+多尺度鉴别器)for dr, dg in zip(disc_real_outputs, disc_generated_outputs):# 真实音频损失:希望判别分数dr接近1,用(1-dr)^2衡量偏差r_loss = torch.mean((1 - dr) **2)# 生成音频损失:希望判别分数dg接近0,用dg^2衡量偏差g_loss = torch.mean(dg** 2)# 累加单个鉴别器的总损失loss += (r_loss + g_loss)# 记录单个鉴别器的损失值(用于监控训练过程)r_losses.append(r_loss.item())g_losses.append(g_loss.item())return loss, r_losses, g_losses
关键细节
- 输入:
disc_real_outputs
是所有鉴别器对真实音频的判别分数列表,disc_generated_outputs
是所有鉴别器对生成音频的判别分数列表。 - 计算方式:
- 对真实音频:使用平方损失
(1 - dr)^2
,当dr=1
时损失为0(完美判别),dr
越小损失越大。 - 对生成音频:使用平方损失
dg^2
,当dg=0
时损失为0(完美判别),dg
越大损失越大。 - 总损失为所有鉴别器的真实损失与生成损失之和。
- 对真实音频:使用平方损失
- 作用:推动鉴别器学习真实音频与生成音频的差异,提升分类能力。每个鉴别器(多周期/多尺度)的损失被单独记录,便于监控不同鉴别器的训练状态。
3. 生成器损失(generator_loss)
生成器的目标是“欺骗”鉴别器,使鉴别器对生成音频的判别分数接近1。该损失函数量化了生成器的欺骗效果,指导其优化以生成更逼真的音频。
def generator_loss(disc_outputs):loss = 0gen_losses = [] # 记录每个鉴别器上的生成损失# 遍历所有鉴别器对生成音频的输出for dg in disc_outputs:# 生成损失:希望判别分数dg接近1,用(1-dg)^2衡量偏差l = torch.mean((1 - dg) **2)gen_losses.append(l)loss += lreturn loss, gen_losses
####** 关键细节 - 输入 :disc_outputs
是所有鉴别器对生成音频的判别分数列表(与disc_generated_outputs
一致)。
- 计算方式 :对每个鉴别器的生成音频判别分数dg
,使用平方损失(1 - dg)^2
,当dg=1
时损失为0(完美欺骗),dg
越小损失越大。总损失为所有鉴别器的生成损失之和。
- 作用 **:推动生成器优化输出,使生成音频在所有鉴别器(多周期/多尺度)上都被误认为真实音频,迫使生成器学习真实音频的全面特征(周期性、时序连续性等)。
###** 4. 损失函数的协同作用 在实际训练中,三类损失函数通过以下方式协同工作:
1. 鉴别器优化 :单独最小化discriminator_loss
,使其能更准确地区分真假音频。
2. 生成器优化 **:最小化“生成器损失 + 特征匹配损失”(通常特征匹配损失会乘以一个权重系数,如10),既要求生成音频能欺骗鉴别器(对抗目标),又要求其特征分布接近真实音频(特征匹配目标)。
这种组合避免了GAN训练中常见的“模式崩溃”(生成样本多样性不足)和“训练不稳定”问题,同时保证了生成音频的高质量(细节丰富、真实感强)。
###** 总结 - 特征匹配损失 :从特征层面约束生成音频与真实音频的一致性,提升细节质量。
- 鉴别器损失 :指导鉴别器学习真假音频的差异,增强判别能力。
- 生成器损失 **:指导生成器欺骗鉴别器,推动生成更逼真的音频。
三者协同形成了完整的训练目标,使生成器能够逐步学习真实音频的分布,最终生成高保真、自然的音频波形。
总结
该代码实现了一个高性能音频生成模型,核心设计包括:
- 生成器:通过多步上采样和残差块,从梅尔频谱生成音频。
- 鉴别器:多周期+多尺度设计,从不同角度区分真假音频,增强判别能力。
- 损失函数:结合对抗损失和特征匹配损失,平衡生成质量和训练稳定性。
这种结构常用于语音合成系统中的声码器(如TTS中的最后一步:从梅尔频谱生成波形),能生成高质量、高保真的音频。