hubert模型代码分析
首先给出模型结构图:
然后给出完整的代码:
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tupleimport numpy as np
import torch
import torch.nn as nn
from omegaconf import IIfrom fairseq import utils
from fairseq.data.data_utils import compute_mask_indices
from fairseq.data.dictionary import Dictionary
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec.wav2vec2 import (EXTRACTOR_MODE_CHOICES,MASKING_DISTRIBUTION_CHOICES,LAYER_TYPE_CHOICES,ConvFeatureExtractionModel,TransformerEncoder,
)
from fairseq.modules import GradMultiply, LayerNorm
from fairseq.tasks.hubert_pretraining import (HubertPretrainingConfig,HubertPretrainingTask,
)logger = logging.getLogger(__name__)@dataclass
class HubertConfig(FairseqDataclass):label_rate: float = II("task.label_rate")extractor_mode: EXTRACTOR_MODE_CHOICES = field(default="default",metadata={"help": "mode for feature extractor. default has a single group ""norm with d groups in the first conv block, whereas layer_norm ""has layer norms in every block (meant to use with normalize=True)"},)encoder_layers: int = field(default=12, metadata={"help": "num encoder layers in the transformer"})encoder_embed_dim: int = field(default=768, metadata={"help": "encoder embedding dimension"})encoder_ffn_embed_dim: int = field(default=3072, metadata={"help": "encoder embedding dimension for FFN"})encoder_attention_heads: int = field(default=12, metadata={"help": "num encoder attention heads"})activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(default="gelu", metadata={"help": "activation function to use"})layer_type: LAYER_TYPE_CHOICES = field(default="transformer", metadata={"help": "layer type in encoder"})# dropoutsdropout: float = field(default=0.1,metadata={"help": "dropout probability for the transformer"},)attention_dropout: float = field(default=0.1,metadata={"help": "dropout probability for attention weights"},)activation_dropout: float = field(default=0.0,metadata={"help": "dropout probability after activation in FFN"},)encoder_layerdrop: float = field(default=0.0,metadata={"help": "probability of dropping a tarnsformer layer"},)dropout_input: float = field(default=0.0,metadata={"help": "dropout to apply to the input (after feat extr)"},)dropout_features: float = field(default=0.0,metadata={"help": "dropout to apply to the features (after feat extr)"},)final_dim: int = field(default=0,metadata={"help": "project final representations and targets to this many ""dimensions. set to encoder_embed_dim is <= 0"},)untie_final_proj: bool = field(default=False,metadata={"help": "use separate projection for each target"},)layer_norm_first: bool = field(default=False,metadata={"help": "apply layernorm first in the transformer"},)conv_feature_layers: str = field(default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",metadata={"help": "string describing convolutional feature extraction ""layers in form of a python list that contains ""[(dim, kernel_size, stride), ...]"},)conv_bias: bool = field(default=False, metadata={"help": "include bias in conv encoder"})logit_temp: float = field(default=0.1, metadata={"help": "temperature to divide logits by"})target_glu: bool = field(default=False, metadata={"help": "adds projection + glu to targets"})feature_grad_mult: float = field(default=1.0,metadata={"help": "multiply feature extractor var grads by this"},)# maskingmask_length: int = field(default=10, metadata={"help": "mask length"})mask_prob: float = field(default=0.65,metadata={"help": "probability of replacing a token with mask"},)mask_selection: MASKING_DISTRIBUTION_CHOICES = field(default="static", metadata={"help": "how to choose mask length"})mask_other: float = field(default=0,metadata={"help": "secondary mask argument ""(used for more complex distributions), ""see help in compute_mask_indicesh"},)no_mask_overlap: bool = field(default=False, metadata={"help": "whether to allow masks to overlap"})mask_min_space: int = field(default=1,metadata={"help": "min space between spans (if no overlap is enabled)"},)# channel maskingmask_channel_length: int = field(default=10,metadata={"help": "length of the mask for features (channels)"},)mask_channel_prob: float = field(default=0.0,metadata={"help": "probability of replacing a feature with 0"},)mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(default="static",metadata={"help": "how to choose mask length for channel masking"},)mask_channel_other: float = field(default=0,metadata={"help": "secondary mask argument ""(used for more complex distributions), ""see help in compute_mask_indicesh"},)no_mask_channel_overlap: bool = field(default=False,metadata={"help": "whether to allow channel masks to overlap"},)mask_channel_min_space: int = field(default=1,metadata={"help": "min space between spans (if no overlap is enabled)"},)# positional embeddingsconv_pos: int = field(default=128,metadata={"help": "number of filters for convolutional positional embeddings"},)conv_pos_groups: int = field(default=16,metadata={"help": "number of groups for convolutional positional embedding"},)conv_pos_batch_norm: bool = field(default=False,metadata={"help": "use batch norm instead of weight norm in conv_pos (for bf16 models)"},)latent_temp: Tuple[float, float, float] = field(default=(2, 0.5, 0.999995),metadata={"help": "legacy (to be removed)"},)# loss computationskip_masked: bool = field(default=False,metadata={"help": "skip computing losses over masked frames"},)skip_nomask: bool = field(default=False,metadata={"help": "skip computing losses over unmasked frames"},)checkpoint_activations: bool = field(default=False,metadata={"help": "recompute activations and save memory for extra compute"},)# FP16 optimizationrequired_seq_len_multiple: int = field(default=2,metadata={"help": "pad the input to encoder such that the sequence length is divisible by multiple"},)# Conformerdepthwise_conv_kernel_size: int = field(default=31,metadata={"help": "depthwise-conv-kernel-size for convolution in conformer layer"},)attn_type: str = field(default="",metadata={"help": "if espnet use ESPNET MHA"},)pos_enc_type: str = field(default="abs",metadata={"help": "Positional encoding type to use in conformer"},)fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"})@register_model("hubert", dataclass=HubertConfig)
class HubertModel(BaseFairseqModel):def __init__(self,cfg: HubertConfig,task_cfg: HubertPretrainingConfig,dictionaries: List[Dictionary],) -> None:super().__init__()logger.info(f"HubertModel Config: {cfg}")feature_enc_layers = eval(cfg.conv_feature_layers) # noqaself.embed = feature_enc_layers[-1][0]self.feature_extractor = ConvFeatureExtractionModel(conv_layers=feature_enc_layers,dropout=0.0,mode=cfg.extractor_mode,conv_bias=cfg.conv_bias,)feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rateself.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)if self.embed != cfg.encoder_embed_dimelse None)self.mask_prob = cfg.mask_probself.mask_selection = cfg.mask_selectionself.mask_other = cfg.mask_otherself.mask_length = cfg.mask_lengthself.no_mask_overlap = cfg.no_mask_overlapself.mask_min_space = cfg.mask_min_spaceself.mask_channel_prob = cfg.mask_channel_probself.mask_channel_selection = cfg.mask_channel_selectionself.mask_channel_other = cfg.mask_channel_otherself.mask_channel_length = cfg.mask_channel_lengthself.no_mask_channel_overlap = cfg.no_mask_channel_overlapself.mask_channel_min_space = cfg.mask_channel_min_spaceself.dropout_input = nn.Dropout(cfg.dropout_input)self.dropout_features = nn.Dropout(cfg.dropout_features)self.feature_grad_mult = cfg.feature_grad_multself.logit_temp = cfg.logit_tempself.skip_masked = cfg.skip_maskedself.skip_nomask = cfg.skip_nomaskfinal_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dimself.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())self.encoder = TransformerEncoder(cfg)self.layer_norm = LayerNorm(self.embed)self.target_glu = Noneif cfg.target_glu:self.target_glu = nn.Sequential(nn.Linear(final_dim, final_dim * 2), nn.GLU())self.untie_final_proj = cfg.untie_final_projif self.untie_final_proj:self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim * len(dictionaries))else:self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)# modules below are not needed during fine-tuningif any([d is None for d in dictionaries]):logger.info("cannot find dictionary. assume will be used for fine-tuning")else:self.num_classes = [len(d) for d in dictionaries]self.label_embs_concat = nn.Parameter(torch.FloatTensor(sum(self.num_classes), final_dim))nn.init.uniform_(self.label_embs_concat)def upgrade_state_dict_named(self, state_dict, name):"""Upgrade a (possibly old) state dict for new versions of fairseq."""super().upgrade_state_dict_named(state_dict, name)return state_dict@classmethoddef build_model(cls, cfg: HubertConfig, task: HubertPretrainingTask):"""Build a new model instance."""model = HubertModel(cfg, task.cfg, task.dictionaries)return modeldef apply_mask(self, x, padding_mask, target_list):B, T, C = x.shapeif self.mask_prob > 0:mask_indices = compute_mask_indices((B, T),padding_mask,self.mask_prob,self.mask_length,self.mask_selection,self.mask_other,min_masks=2,no_overlap=self.no_mask_overlap,min_space=self.mask_min_space,)mask_indices = torch.from_numpy(mask_indices).to(x.device)x[mask_indices] = self.mask_embelse:mask_indices = Noneif self.mask_channel_prob > 0:mask_channel_indices = compute_mask_indices((B, C),None,self.mask_channel_prob,self.mask_channel_length,self.mask_channel_selection,self.mask_channel_other,no_overlap=self.no_mask_channel_overlap,min_space=self.mask_channel_min_space,)mask_channel_indices = (torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1))x[mask_channel_indices] = 0return x, mask_indicesdef compute_nce(self, x, pos, negs):neg_is_pos = (pos == negs).all(-1)pos = pos.unsqueeze(0)targets = torch.cat([pos, negs], dim=0)logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)logits /= self.logit_tempif neg_is_pos.any():logits[1:][neg_is_pos] = float("-inf")logits = logits.transpose(0, 1) # (num_x, num_cls+1)return logitsdef forward_features(self, source: torch.Tensor) -> torch.Tensor:if self.feature_grad_mult > 0:features = self.feature_extractor(source)if self.feature_grad_mult != 1.0:features = GradMultiply.apply(features, self.feature_grad_mult)else:with torch.no_grad():features = self.feature_extractor(source)return featuresdef forward_targets(self,features: torch.Tensor,target_list: List[torch.Tensor],) -> Tuple[torch.Tensor, torch.Tensor]:# Trim features to ensure labels exist and then get aligned labelsfeat_tsz = features.size(2)targ_tsz = min([t.size(1) for t in target_list])if self.feat2tar_ratio * feat_tsz > targ_tsz:feat_tsz = int(targ_tsz / self.feat2tar_ratio)features = features[..., :feat_tsz]target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratiotarget_list = [t[:, target_inds.long()] for t in target_list]return features, target_listdef forward_padding_mask(self,features: torch.Tensor,padding_mask: torch.Tensor,) -> torch.Tensor:extra = padding_mask.size(1) % features.size(1)if extra > 0:padding_mask = padding_mask[:, :-extra]padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)padding_mask = padding_mask.all(-1)return padding_maskdef forward(self,source: torch.Tensor,target_list: Optional[List[torch.Tensor]] = None,padding_mask: Optional[torch.Tensor] = None,mask: bool = True,features_only: bool = False,output_layer: Optional[int] = None,) -> Dict[str, torch.Tensor]:"""output layer is 1-based"""features = self.forward_features(source)if target_list is not None:features, target_list = self.forward_targets(features, target_list)features_pen = features.float().pow(2).mean()features = features.transpose(1, 2)features = self.layer_norm(features)unmasked_features = features.clone()if padding_mask is not None:padding_mask = self.forward_padding_mask(features, padding_mask)if self.post_extract_proj is not None:features = self.post_extract_proj(features)features = self.dropout_input(features)unmasked_features = self.dropout_features(unmasked_features)if mask:x, mask_indices = self.apply_mask(features, padding_mask, target_list)else:x = featuresmask_indices = None# feature: (B, T, D), float# target: (B, T), long# x: (B, T, D), float# padding_mask: (B, T), bool# mask_indices: (B, T), boolx, _ = self.encoder(x,padding_mask=padding_mask,layer=None if output_layer is None else output_layer - 1,)if features_only:return {"x": x, "padding_mask": padding_mask, "features": features}def compute_pred(proj_x, target, label_embs):# compute logits for the i-th label sety = torch.index_select(label_embs, 0, target.long())negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)if self.target_glu:y = self.target_glu(y)negs = self.target_glu(negs)# proj_x: (S, D)# y: (S, D)# negs: (Neg, S, D)return self.compute_nce(proj_x, y, negs)label_embs_list = self.label_embs_concat.split(self.num_classes, 0)if not self.skip_masked:masked_indices = torch.logical_and(~padding_mask, mask_indices)proj_x_m = self.final_proj(x[masked_indices])if self.untie_final_proj:proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)else:proj_x_m_list = [proj_x_m for _ in range(len(target_list))]logit_m_list = [compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))]else:logit_m_list = [None for _ in target_list]if not self.skip_nomask:nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)proj_x_u = self.final_proj(x[nomask_indices])if self.untie_final_proj:proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)else:proj_x_u_list = [proj_x_u for _ in range(len(target_list))]logit_u_list = [compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list))]else:logit_u_list = [None for _ in target_list]result = {"logit_m_list": logit_m_list,"logit_u_list": logit_u_list,"padding_mask": padding_mask,"features_pen": features_pen,}return resultdef extract_features(self,source: torch.Tensor,padding_mask: Optional[torch.Tensor] = None,mask: bool = False,ret_conv: bool = False,output_layer: Optional[int] = None,) -> Tuple[torch.Tensor, torch.Tensor]:res = self.forward(source,padding_mask=padding_mask,mask=mask,features_only=True,output_layer=output_layer,)feature = res["features"] if ret_conv else res["x"]return feature, res["padding_mask"]def get_logits(self, net_output, is_masked=True):if is_masked:logits_list = net_output["logit_m_list"]else:logits_list = net_output["logit_u_list"]logits_list = [x.float() for x in logits_list if x is not None]return logits_listdef get_targets(self, net_output, is_masked=True):logits_list = self.get_logits(net_output, is_masked)targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]return targets_listdef get_extra_losses(self, net_output):extra_losses = []names = []if "features_pen" in net_output:extra_losses.append(net_output["features_pen"])names.append("features_pen")return extra_losses, namesdef remove_pretraining_modules(self):self.target_glu = Noneself.final_proj = None
这段代码实现了HuBERT (Hidden-Unit BERT) 模型,这是一种用于语音信号自监督预训练的经典模型。HuBERT的核心思想是通过掩码语音特征并预测其对应的离散标签(无监督聚类得到),从而学习语音的通用表示。以下是对代码的详细解析:
1. 核心依赖与背景
- 基于
fairseq
框架(Facebook的序列建模工具包)实现,用于高效训练序列模型。 - 采用自监督学习范式,通过掩码输入特征并预测标签,无需人工标注数据即可学习语音表示。
- 结合了卷积特征提取器(处理原始语音)和Transformer编码器(学习上下文依赖),类似BERT但针对语音模态。
2. HubertConfig
:模型配置参数
HubertConfig
是一个基于 dataclass
的配置类,用于集中管理 HuBERT 模型的所有超参数和结构配置。它继承自 FairseqDataclass
(fairseq 框架的基础配置类),通过字段(field
)定义了模型从特征提取、编码到掩码策略、损失计算的所有细节。这些参数决定了模型的结构、训练行为和性能,是理解 HuBERT 工作机制的关键。
核心设计目的
HubertConfig
的核心作用是将模型的所有可配置参数集中管理,包括网络结构(如卷积层、Transformer 层)、训练策略(如掩码概率)、正则化(如 dropout)等。通过这种方式,用户可以通过修改配置参数灵活调整模型,而无需修改模型代码本身。
参数详细解析
以下按功能分类解析关键参数:
1. 标签与特征对齐参数
label_rate: float = II("task.label_rate")
- 含义:标签的采样率(单位:Hz),用于将模型提取的语音特征与目标标签(如聚类结果)在时间维度上对齐。
- 细节:
II("task.label_rate")
表示该参数从任务配置(HubertPretrainingTask
)中继承,避免重复定义。
2. 特征提取器配置(卷积层)
特征提取器是 HuBERT 的前端,用于从原始语音波形中提取频谱特征,由一系列卷积层组成。相关参数如下:
-
extractor_mode: EXTRACTOR_MODE_CHOICES = "default"
- 含义:特征提取器的归一化模式,决定卷积层的归一化方式。
- 可选值:
default
(默认)和layer_norm
。default
:第一个卷积块使用组归一化(group norm),其余层无归一化;layer_norm
:每个卷积块后都使用层归一化(layer norm),通常配合normalize=True
使用。
-
conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2"
- 含义:卷积特征提取层的定义,以 Python 列表字符串形式表示,每个元素为
(输出维度, 卷积核大小, 步长)
。 - 默认解析:
- 第一层:输出维度 512,卷积核 10,步长 5(快速降低时间维度);
- 中间 4 层:输出维度 512,卷积核 3,步长 2(逐步压缩时间维度);
- 最后 2 层:输出维度 512,卷积核 2,步长 2(进一步压缩)。
- 作用:通过多层卷积将原始语音(如 16kHz 采样率)转换为低时间分辨率、高特征维度的表示(通常时间维度压缩 100 倍左右)。
- 含义:卷积特征提取层的定义,以 Python 列表字符串形式表示,每个元素为
-
conv_bias: bool = False
- 含义:卷积层是否使用偏置参数。默认
False
可减少参数数量,降低过拟合风险。
- 含义:卷积层是否使用偏置参数。默认
-
feature_grad_mult: float = 1.0
- 含义:特征提取器梯度的缩放因子。
- 作用:预训练时可调整该值(如设为 0.1),减弱特征提取器的更新强度,避免其过度拟合局部特征。
3. Transformer 编码器配置
Transformer 编码器是 HuBERT 的核心,用于学习特征的上下文依赖关系。相关参数如下:
-
encoder_layers: int = 12
- 含义:Transformer 编码器的层数。默认 12 层,层数越多,模型可学习的上下文复杂度越高(但计算成本也越高)。
-
encoder_embed_dim: int = 768
- 含义:Transformer 编码器的输入/输出嵌入维度。默认 768,是平衡模型容量和计算量的经验值。
-
encoder_ffn_embed_dim: int = 3072
- 含义:Transformer 前馈网络(FFN)的隐藏层维度。默认 3072(通常为
encoder_embed_dim
的 4 倍),为非线性变换提供足够容量。
- 含义:Transformer 前馈网络(FFN)的隐藏层维度。默认 3072(通常为
-
encoder_attention_heads: int = 12
- 含义:多头注意力的头数。默认 12,
encoder_embed_dim
(768)需能被头数整除(768/12=64,即每个头的维度为 64)。
- 含义:多头注意力的头数。默认 12,
-
activation_fn: str = "gelu"
- 含义:激活函数类型。默认
gelu
(高斯误差线性单元),在 Transformer 中性能优于relu
。
- 含义:激活函数类型。默认
-
layer_type: LAYER_TYPE_CHOICES = "transformer"
- 含义:编码器层的类型。默认
transformer
,也可选择conformer
(适用于更长序列的语音任务)。
- 含义:编码器层的类型。默认
-
layer_norm_first: bool = False
- 含义:Transformer 层中归一化的位置。
False
表示先做注意力/FFN,再做层归一化(常规 Transformer 做法);True
表示先归一化再做变换。
- 含义:Transformer 层中归一化的位置。
4. 正则化参数(Dropout)
一系列 dropout 参数用于防止模型过拟合,作用于不同位置:
dropout: float = 0.1
:Transformer 层的整体 dropout(如残差连接后)。attention_dropout: float = 0.1
:注意力权重的 dropout(防止注意力过度集中于某些位置)。activation_dropout: float = 0.0
:FFN 激活后的 dropout。encoder_layerdrop: float = 0.0
:随机丢弃整个 Transformer 层的概率(增强鲁棒性)。dropout_input: float = 0.0
:特征输入到 Transformer 前的 dropout。dropout_features: float = 0.0
:卷积特征提取后的 dropout。
5. 掩码策略参数(自监督核心)
HuBERT 通过掩码输入特征迫使模型学习上下文依赖,这是自监督预训练的核心。掩码分为时间维度(掩盖时间步)和通道维度(掩盖特征通道)两类:
(1) 时间维度掩码
mask_prob: float = 0.65
:随机掩码时间步的概率(默认 65%,确保足够多的位置需要预测)。mask_length: int = 10
:每个掩码块的长度(默认 10 个时间步)。mask_selection: MASKING_DISTRIBUTION_CHOICES = "static"
:掩码长度的分布策略。static
:固定长度(即mask_length
);uniform
:均匀分布在 1 到mask_length
之间;normal
:正态分布(均值为mask_length
)。
mask_other: float = 0
:辅助参数,用于复杂掩码分布(如同时掩码部分位置为随机值)。no_mask_overlap: bool = False
:是否允许掩码块重叠(False
允许重叠,增加掩码复杂度)。mask_min_space: int = 1
:若不允许重叠,掩码块之间的最小间隔(默认 1 个时间步)。
(2) 通道维度掩码
mask_channel_prob: float = 0.0
:随机掩码特征通道的概率(默认 0,即不掩码通道;可根据任务开启)。mask_channel_length: int = 10
:每个通道掩码块的长度。mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = "static"
:通道掩码长度的分布策略(同时间掩码)。- 其他参数(
mask_channel_other
、no_mask_channel_overlap
等):与时间掩码逻辑一致,仅作用于通道维度。
6. 投影与损失计算参数
final_dim: int = 0
:最终投影层的输出维度。若为 0,则等于encoder_embed_dim
(默认 768)。untie_final_proj: bool = False
:是否为每个目标标签集使用独立的投影层。False
(默认):所有标签共享一个投影层;True
:每个标签集有独立投影层(适用于多任务场景)。
logit_temp: float = 0.1
:logits 的温度系数,用于缩放余弦相似度(温度越低,概率分布越集中)。target_glu: bool = False
:是否对目标标签嵌入应用GLU
激活(Linear + GLU
),增强非线性能力。skip_masked: bool = False
:是否跳过掩码位置的损失计算(默认False
,即计算掩码位置损失,这是自监督的核心)。skip_nomask: bool = False
:是否跳过非掩码位置的损失计算(默认False
,同时利用掩码和非掩码位置优化)。
7. 位置嵌入参数
conv_pos: int = 128
:卷积位置嵌入的滤波器数量,用于学习序列的位置信息(替代传统的正弦位置嵌入)。conv_pos_groups: int = 16
:卷积位置嵌入的组数(用于组归一化)。
8. 训练优化参数
checkpoint_activations: bool = False
:是否 checkpoint 激活值(节省内存,代价是额外计算)。required_seq_len_multiple: int = 2
:输入序列长度需满足的倍数(用于对齐 Transformer 计算,避免维度不匹配)。fp16: bool = False
:是否使用 FP16 混合精度训练(加速训练,减少内存占用)。
总结
HubertConfig
是 HuBERT 模型的“控制面板”,通过这些参数可以灵活调整模型的结构(如卷积层、Transformer 层)、训练策略(如掩码概率、dropout)和优化方式(如精度、内存控制)。这些参数的设计基于语音信号的特性(如时间连续性、频谱特性)和自监督学习的需求(如掩码预测),是模型性能的关键影响因素。
例如,较高的 mask_prob
迫使模型更多依赖上下文预测,增强特征的鲁棒性;更深的 encoder_layers
可学习更长范围的上下文依赖,但会增加计算成本。通过调整这些参数,HuBERT 可适配不同的语音任务(如语音识别、情感分析)和数据规模。
3. HubertModel
:模型核心实现
HubertModel
是 HuBERT 模型的核心实现类,继承自 BaseFairseqModel
(fairseq 框架的基础模型类),封装了从原始语音输入到特征提取、掩码处理、上下文编码再到损失计算的完整逻辑。其核心目标是通过自监督学习(掩码特征预测)从无标注语音数据中学习通用语音表示。
核心设计思路
HuBERT 的核心逻辑可概括为:
- 特征提取:用卷积网络从原始语音波形中提取频谱特征;
- 掩码操作:随机掩码部分特征(时间/通道维度),迫使模型依赖上下文推断被掩盖的信息;
- 上下文编码:用 Transformer 编码器处理掩码后的特征,学习上下文依赖;
- 对比损失:通过对比学习(预测掩码位置对应的离散标签)优化模型参数。
关键组件与方法解析
1. 初始化方法(__init__
)
初始化模型的核心组件,将 HubertConfig
配置参数转化为可训练的网络层。主要包含以下部分:
def __init__(self, cfg: HubertConfig, task_cfg: HubertPretrainingConfig, dictionaries: List[Dictionary]):super().__init__()# 解析卷积特征提取层配置(如[(512,10,5), ...])feature_enc_layers = eval(cfg.conv_feature_layers)self.embed = feature_enc_layers[-1][0] # 最后一层卷积的输出维度# 1. 卷积特征提取器:从原始语音提取特征self.feature_extractor = ConvFeatureExtractionModel(conv_layers=feature_enc_layers,dropout=0.0,mode=cfg.extractor_mode,conv_bias=cfg.conv_bias,)# 特征与标签的时间对齐比例(用于匹配特征和目标标签的时间维度)feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers]) # 卷积总下采样率self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate# 2. 特征投影层:将卷积特征维度转换为Transformer输入维度self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)if self.embed != cfg.encoder_embed_dimelse None)# 3. 掩码相关参数(从配置中加载)self.mask_prob = cfg.mask_probself.mask_length = cfg.mask_length# ... 其他掩码参数(略)# 4. Dropout层:正则化self.dropout_input = nn.Dropout(cfg.dropout_input)self.dropout_features = nn.Dropout(cfg.dropout_features)# 5. 特征提取器梯度缩放因子self.feature_grad_mult = cfg.feature_grad_mult# 6. 可学习的掩码嵌入:用于替换被掩码的特征self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())# 7. Transformer编码器:学习上下文依赖self.encoder = TransformerEncoder(cfg)# 8. 输出投影与标签嵌入(用于对比损失计算)final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dimself.untie_final_proj = cfg.untie_final_projself.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim * len(dictionaries)) if self.untie_final_proj else nn.Linear(cfg.encoder_embed_dim, final_dim)# 9. 标签嵌入矩阵(聚类中心的嵌入,用于对比学习)self.num_classes = [len(d) for d in dictionaries]self.label_embs_concat = nn.Parameter(torch.FloatTensor(sum(self.num_classes), final_dim))nn.init.uniform_(self.label_embs_concat)
核心组件说明:
feature_extractor
:卷积神经网络(CNN),将原始语音波形(如16kHz采样)转换为高维特征序列(时间维度被下采样约100倍)。post_extract_proj
:线性投影层,确保卷积特征维度与 Transformer 输入维度一致。mask_emb
:可学习的向量,用于替换被掩码的特征(类似BERT中的[MASK]
标记)。encoder
:Transformer 编码器,通过多头注意力捕捉特征序列的上下文依赖。label_embs_concat
:聚类标签的嵌入矩阵(预训练阶段通过无监督聚类生成标签),用于计算对比损失。
2. 掩码操作(apply_mask
)
自监督学习的核心步骤:随机掩码特征的时间或通道维度,迫使模型学习上下文推断能力。
def apply_mask(self, x, padding_mask, target_list):B, T, C = x.shape # B:批次,T:时间步,C:特征维度mask_indices = None# 时间维度掩码:随机选择时间步,用mask_emb替换if self.mask_prob > 0:# 生成掩码位置(布尔矩阵,True表示被掩码)mask_indices = compute_mask_indices((B, T), # 输入形状padding_mask, # padding位置(不参与掩码)self.mask_prob, # 掩码概率self.mask_length, # 掩码长度self.mask_selection, # 掩码分布策略# ... 其他掩码参数)mask_indices = torch.from_numpy(mask_indices).to(x.device)x[mask_indices] = self.mask_emb # 用mask_emb替换掩码位置# 通道维度掩码:随机选择特征通道,置零if self.mask_channel_prob > 0:mask_channel_indices = compute_mask_indices((B, C), # 通道维度形状None, # 无paddingself.mask_channel_prob, # 通道掩码概率# ... 其他通道掩码参数)mask_channel_indices = torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1)x[mask_channel_indices] = 0 # 通道掩码位置置零return x, mask_indices
关键逻辑:
- 时间掩码:通过
compute_mask_indices
生成符合配置(概率、长度、分布)的掩码位置,用mask_emb
替换这些位置的特征。 - 通道掩码:类似时间掩码,但作用于特征通道维度,直接将掩码位置的特征置零(破坏通道信息)。
3. 前向传播(forward
)
模型的核心流程,从输入语音到输出损失相关的logits。
def forward(self, source: torch.Tensor, target_list=None, padding_mask=None, mask=True, features_only=False, output_layer=None):# 1. 特征提取:从原始语音(source)提取卷积特征features = self.forward_features(source) # 输出:(B, C, T),C为特征维度,T为时间步# 2. 特征与目标标签对齐(调整时间维度,确保特征和标签长度匹配)if target_list is not None:features, target_list = self.forward_targets(features, target_list)# 3. 特征预处理:转置维度+层归一化features = features.transpose(1, 2) # (B, T, C)features = self.layer_norm(features)unmasked_features = features.clone() # 保存未掩码特征(用于后续计算)# 4. 处理padding掩码(标记无效语音片段,不参与计算)if padding_mask is not None:padding_mask = self.forward_padding_mask(features, padding_mask)# 5. 特征投影+Dropoutif self.post_extract_proj is not None:features = self.post_extract_proj(features) # 转换为Transformer输入维度features = self.dropout_input(features)unmasked_features = self.dropout_features(unmasked_features)# 6. 应用掩码(时间/通道)if mask:x, mask_indices = self.apply_mask(features, padding_mask, target_list)else:x = featuresmask_indices = None# 7. Transformer编码:学习上下文表示x, _ = self.encoder(x, padding_mask=padding_mask, layer=output_layer - 1 if output_layer else None)# 若仅需特征(如微调阶段),直接返回if features_only:return {"x": x, "padding_mask": padding_mask, "features": features}# 8. 计算对比损失的logits(区分掩码和非掩码位置)# ... (省略logit计算细节,见下文)return result # 包含logits、掩码信息、损失项等
步骤解析:
- 特征提取:
forward_features
调用卷积网络,将原始语音(shape(B, 1, T_wav)
)转换为特征序列(shape(B, C, T_feat)
)。 - 时间对齐:
forward_targets
调整特征长度,使其与目标标签(聚类结果)的时间维度匹配(基于feat2tar_ratio
计算对齐索引)。 - 掩码与编码:掩码后的特征通过 Transformer 编码器,输出上下文感知的表示(shape
(B, T_feat, D)
,D为编码器维度)。 - logit计算:对掩码位置(
masked_indices
)和非掩码位置(nomask_indices
)分别计算与标签嵌入的相似度(用于损失计算)。
4. 对比损失计算(compute_nce
)
通过噪声对比估计(NCE)计算损失,核心是对比模型输出与正负样本的相似度。
def compute_nce(self, x, pos, negs):# x: 模型输出特征 (S, D),S为样本数,D为特征维度# pos: 正样本(目标标签嵌入)(S, D)# negs: 负样本(其他标签嵌入)(Neg, S, D),Neg为负样本数# 标记负样本中与正样本相同的向量(需排除)neg_is_pos = (pos == negs).all(-1)pos = pos.unsqueeze(0) # (1, S, D)targets = torch.cat([pos, negs], dim=0) # (1+Neg, S, D)# 计算余弦相似度(作为logits)logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)logits /= self.logit_temp # 用温度系数缩放# 排除与正样本相同的负样本if neg_is_pos.any():logits[1:][neg_is_pos] = float("-inf")logits = logits.transpose(0, 1) # (S, 1+Neg),每行对应一个样本的正负相似度return logits
核心逻辑:
- 输入为模型输出特征
x
、正样本(目标标签的嵌入pos
)和负样本(其他标签的嵌入negs
)。 - 通过余弦相似度衡量
x
与正负样本的匹配程度,作为logits。 - 温度系数
logit_temp
控制分布的“陡峭度”(值越小,概率越集中于高相似度样本)。 - 最终logits用于计算交叉熵损失(目标是让模型对正样本的相似度最高)。
5. 辅助方法
forward_features
:封装特征提取逻辑,支持梯度缩放(通过GradMultiply
控制特征提取器的更新强度)。forward_targets
:调整特征和标签的时间维度,确保一一对应(基于卷积下采样率和标签采样率计算对齐索引)。extract_features
:用于微调阶段,提取模型中间特征(如Transformer输出)作为下游任务的输入。remove_pretraining_modules
:微调时移除预训练相关模块(如label_embs_concat
、final_proj
),减少冗余参数。
核心功能总结
HubertModel
实现了 HuBERT 从原始语音到自监督学习的完整流程,其核心创新点在于:
- 双维度掩码:同时对时间和通道维度进行掩码,强制模型学习语音的多维度特征依赖。
- 对比学习:通过预测掩码位置对应的离散标签(聚类结果),在无标注数据上学习通用语音表示。
- 灵活适配:支持通过配置参数调整网络结构(如卷积层、Transformer层数)和训练策略(如掩码概率),适配不同语音任务。
与预训练/微调的关系
- 预训练阶段:
forward
方法输出掩码/非掩码位置的logits,通过对比损失优化模型,学习语音的通用表示。 - 微调阶段:调用
extract_features
提取编码器输出特征,结合下游任务(如语音识别)的头网络进行训练,此时通常关闭掩码(mask=False
)。
通过这种“预训练+微调”的范式,HuBERT 能在少量标注数据上取得优异的语音任务性能。
4. 核心功能总结
- 自监督预训练:通过掩码语音特征并预测其对应的聚类标签,学习语音的通用表示。
- 特征提取:从原始语音中提取层次化特征(卷积特征→Transformer编码特征)。
- 灵活配置:支持多种掩码策略、网络结构参数,适配不同语音任务。
5. 应用场景
- 语音识别:预训练模型可微调到特定语言的语音识别任务,提升性能。
- 语音分类:如情感识别、说话人识别等,利用预训练特征作为输入。
- 语音生成:作为语音编码器,为生成模型提供高质量输入特征。
总之,这段代码完整实现了HuBERT模型的核心逻辑,是语音自监督学习的经典实现,通过掩码策略和对比损失让模型从无标注语音数据中学习有用的表示。