当前位置: 首页 > news >正文

audioMAE模型代码分析

VisionTransformer

audioMAE的核心是由ViT模型构成,所以先对ViT模型进行解释。
先给出完整代码:

from functools import partialimport torch
import torch.nn as nn
import numpy as np
import timm.models.vision_transformer
from timm.models.vision_transformer import PatchEmbed, Block
from util.patch_embed import PatchEmbed_new, PatchEmbed3D_newclass VisionTransformer(timm.models.vision_transformer.VisionTransformer):""" Vision Transformer with support for global average pooling"""def __init__(self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs):super(VisionTransformer, self).__init__(**kwargs)self.global_pool = global_poolif self.global_pool:norm_layer = kwargs['norm_layer']embed_dim = kwargs['embed_dim']self.fc_norm = norm_layer(embed_dim)del self.norm  # remove the original normself.mask_2d = mask_2dself.use_custom_patch = use_custom_patchnum_heads=12depth=12mlp_ratio=4def forward_features(self, x):B = x.shape[0]x = self.patch_embed(x)x = x + self.pos_embed[:, 1:, :]cls_token = self.cls_token + self.pos_embed[:, :1, :]cls_tokens = cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanksx = torch.cat((cls_tokens, x), dim=1)x = self.pos_drop(x)        for blk in self.blocks:x = blk(x)if self.global_pool:x = x[:, 1:, :].mean(dim=1)  # global pool without cls tokenoutcome = self.fc_norm(x)else:x = self.norm(x)outcome = x[:, 0]return outcomedef random_masking(self, x, mask_ratio):"""Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape  # batch, length, dimlen_keep = int(L * (1 - mask_ratio))noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]# sort noise for each sampleids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is removeids_restore = torch.argsort(ids_shuffle, dim=1)# keep the first subsetids_keep = ids_shuffle[:, :len_keep]x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))# generate the binary mask: 0 is keep, 1 is removemask = torch.ones([N, L], device=x.device)mask[:, :len_keep] = 0# unshuffle to get the binary maskmask = torch.gather(mask, dim=1, index=ids_restore)return x_masked, mask, ids_restoredef random_masking_2d(self, x, mask_t_prob, mask_f_prob):"""2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob)Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape  # batch, length, dimif self.use_custom_patch:# # for AST=101 #64,101F=12 #8,12# # for ESC# T=50# F=12 # for SPC# T=12# F=12else:# ## for AS T=64F=8# ## for ESC#T=32#F=8            ## for SPC# T=8# F=8# mask Tx = x.reshape(N, T, F, D)len_keep_T = int(T * (1 - mask_t_prob))noise = torch.rand(N, T, device=x.device)  # noise in [0, 1]# sort noise for each sampleids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is removeids_keep = ids_shuffle[:, :len_keep_T]index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D)#x_masked = torch.gather(x, dim=1, index=index)#x_masked = x_masked.reshape(N,len_keep_T*F,D)x = torch.gather(x, dim=1, index=index) # N, len_keep_T(T'), F, D# mask F#x = x.reshape(N, T, F, D)x = x.permute(0,2,1,3) # N T' F D => N F T' Dlen_keep_F = int(F * (1 - mask_f_prob))noise = torch.rand(N, F, device=x.device)  # noise in [0, 1]# sort noise for each sampleids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is removeids_keep = ids_shuffle[:, :len_keep_F]#index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D)index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D)x_masked = torch.gather(x, dim=1, index=index)x_masked = x_masked.permute(0,2,1,3) # N F' T' D => N T' F' D #x_masked = x_masked.reshape(N,len_keep*T,D)x_masked = x_masked.reshape(N,len_keep_F*len_keep_T,D)return x_masked, None, Nonedef forward_features_mask(self, x, mask_t_prob, mask_f_prob):B = x.shape[0] #4,1,1024,128x = self.patch_embed(x) # 4, 512, 768x = x + self.pos_embed[:, 1:, :]if self.random_masking_2d:x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob)else:x, mask, ids_restore = self.random_masking(x, mask_t_prob)cls_token = self.cls_token + self.pos_embed[:, :1, :]cls_tokens = cls_token.expand(B, -1, -1)x = torch.cat((cls_tokens, x), dim=1)        x = self.pos_drop(x)# apply Transformer blocksfor blk in self.blocks:x = blk(x)if self.global_pool:x = x[:, 1:, :].mean(dim=1)  # global pool without cls tokenoutcome = self.fc_norm(x)else:x = self.norm(x)outcome = x[:, 0]return outcome# overwrite original timmdef forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0):if mask_t_prob > 0.0 or mask_f_prob > 0.0:x = self.forward_features_mask(x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)else:x = self.forward_features(x)x = self.head(x)return xdef vit_small_patch16(**kwargs):model = VisionTransformer(patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)        return modeldef vit_base_patch16(**kwargs):model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return modeldef vit_large_patch16(**kwargs):model = VisionTransformer(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return modeldef vit_huge_patch14(**kwargs):model = VisionTransformer(patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return model

1. 整体结构

代码继承自timm库(PyTorch Image Models)的VisionTransformer类,在此基础上扩展了功能,主要包括:

  • 支持全局平均池化(替代传统的cls token)
  • 实现了1D和2D两种掩码机制(用于自监督学习)
  • 支持自定义patch划分方式
  • 提供了不同规模的ViT模型实例化方法(small/base/large/huge)

2. 核心类:VisionTransformer

该类继承自timm.models.vision_transformer.VisionTransformer,并重写/扩展了关键方法。

2.1 构造函数 __init__

__init__ 构造函数是 VisionTransformer 类的初始化方法,它在父类 timm.models.vision_transformer.VisionTransformer 的基础上扩展了新功能,主要用于配置模型的核心参数和组件。以下是详细说明:

函数定义
def __init__(self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs):super(VisionTransformer, self).__init__(** kwargs)  # 调用父类构造函数# 自定义参数初始化self.global_pool = global_poolif self.global_pool:norm_layer = kwargs['norm_layer']embed_dim = kwargs['embed_dim']self.fc_norm = norm_layer(embed_dim)  # 全局池化后的归一化层del self.norm  # 移除父类的归一化层self.mask_2d = mask_2d  # 是否启用2D掩码机制self.use_custom_patch = use_custom_patch  # 是否使用自定义patch划分# 以下参数未实际使用(可能是调试残留或示例)num_heads=12depth=12mlp_ratio=4
核心参数解析
1. 输入参数
  • global_pool(bool,默认 False):

    • 功能:控制模型最终的特征聚合方式。
    • 若为 True:使用「全局平均池化」替代传统ViT的「cls token」作为特征聚合方式。此时会创建 self.fc_norm 作为全局池化后的归一化层。
    • 若为 False:沿用传统ViT的方式,使用cls token聚合特征。
  • mask_2d(bool,默认 True):

    • 功能:指定掩码机制的类型。
    • 若为 True:使用 random_masking_2d 方法(针对二维结构数据,如时间-频率谱图)。
    • 若为 False:使用 random_masking 方法(针对一维序列数据)。
  • use_custom_patch(bool,默认 False):

    • 功能:控制是否使用自定义的patch划分维度(影响 random_masking_2d 中的时间/频率维度设置)。
    • 若为 Truerandom_masking_2d 中会使用自定义的 T(时间维度)和 F(频率维度)值(如代码中注释的 T=101, F=12)。
    • 若为 False:使用默认的 TF 值(如 T=64, F=8)。
  • **kwargs(可变参数):

    • 功能:接收父类 VisionTransformer 所需的全部参数(如 patch_sizeembed_dimdepth 等),并传递给父类构造函数。
    • 父类关键参数包括:patch_size(patch大小)、embed_dim(嵌入维度)、depth(Transformer层数)、num_heads(注意力头数)、norm_layer(归一化层类型)等。
2. 关键操作解析
  • 调用父类构造函数

    super(VisionTransformer, self).__init__(**kwargs)
    

    这行代码初始化了父类 timm.models.vision_transformer.VisionTransformer 的所有基础组件,包括:

    • patch_embed:将输入图像分块并映射到特征维度的模块。
    • pos_embed:位置嵌入(包含cls token的位置)。
    • cls_token:分类标记(用于特征聚合)。
    • blocks:Transformer块的列表。
    • norm:父类默认的归一化层(后续可能被删除)。
    • head:最终的分类头。
  • 全局池化相关配置

    if self.global_pool:norm_layer = kwargs['norm_layer']  # 从kwargs中获取归一化层类型(如nn.LayerNorm)embed_dim = kwargs['embed_dim']    # 获取特征嵌入维度self.fc_norm = norm_layer(embed_dim)  # 创建全局池化后的归一化层
    del self.norm  # 移除父类的默认归一化层
    
    • 当启用 global_pool 时,模型不再需要父类的 norm 层(因为全局池化后使用 fc_norm 归一化),因此显式删除 self.norm
    • 若不启用 global_pool,虽然代码仍会执行 del self.norm,但后续 forward_features 中会发现 self.norm 已被删除,可能存在逻辑漏洞(需结合实际使用场景判断)。
  • 冗余参数
    代码末尾的 num_heads=12depth=12mlp_ratio=4 未被实际使用,可能是调试残留或示例代码,不影响模型功能。

总结

__init__ 构造函数的核心作用是:

  1. 复用父类 VisionTransformer 的基础组件(patch嵌入、Transformer块等)。
  2. 新增 global_poolmask_2duse_custom_patch 三个关键参数,支持:
    • 两种特征聚合方式(cls token/全局池化)。
    • 两种掩码机制(1D/2D)。
    • 自定义patch维度(适应不同数据集)。
  3. 根据 global_pool 动态调整归一化层(删除父类的 norm,新增 fc_norm)。

这些扩展使模型更灵活,尤其适合处理二维结构数据(如音频频谱图)和自监督学习任务。

2.2 特征提取:forward_features

forward_features 方法是 VisionTransformer 类的核心特征提取逻辑,负责将输入数据通过一系列处理(分块嵌入、位置编码、Transformer编码等)转换为高层特征。它是模型从原始输入到最终特征输出的关键流程,以下是详细说明:

函数定义
def forward_features(self, x):B = x.shape[0]  # 获取批次大小(batch size)x = self.patch_embed(x)  # 1. 将输入分块并嵌入到特征空间x = x + self.pos_embed[:, 1:, :]  # 2. 加入位置嵌入(排除cls token的位置)# 3. 处理cls token并与patch特征拼接cls_token = self.cls_token + self.pos_embed[:, :1, :]  # cls token加位置嵌入cls_tokens = cls_token.expand(B, -1, -1)  # 扩展到整个批次x = torch.cat((cls_tokens, x), dim=1)  # 拼接cls token和patch特征序列x = self.pos_drop(x)  # 4. 位置dropout(防止过拟合)# 5. 通过所有Transformer块进行特征编码for blk in self.blocks:x = blk(x)# 6. 特征聚合(根据配置选择cls token或全局池化)if self.global_pool:x = x[:, 1:, :].mean(dim=1)  # 全局平均池化(排除cls token)outcome = self.fc_norm(x)    # 全局池化后的归一化else:x = self.norm(x)  # 传统归一化(使用父类的norm层,若未被删除)outcome = x[:, 0]  # 取cls token作为最终特征return outcome  # 返回提取的高层特征
逐步骤解析
1. 输入与批次大小
  • x 是输入数据,形状通常为 [B, C, H, W]B:批次大小,C:通道数,H/W:高度/宽度),例如图像或频谱图数据。
  • B = x.shape[0] 获取批次大小,用于后续扩展 cls_token 到整个批次。
2. Patch嵌入(self.patch_embed(x)
  • 功能:将输入的二维数据(如图像)分割为固定大小的「patch」(块),并将每个patch映射到高维特征空间。
  • 处理过程
    • 例如,输入为 [B, 3, 224, 224](3通道图像,224×224分辨率),若 patch_size=16,则会被分割为 (224/16)×(224/16)=14×14=196 个patch。
    • 每个patch通过线性投影(或卷积)映射到 embed_dim 维度(如768),输出形状为 [B, 196, 768]B:批次,196:patch数量,768:特征维度)。
  • 对应组件self.patch_embed 是父类初始化的 PatchEmbed 实例,负责分块和嵌入。
3. 位置嵌入(x = x + self.pos_embed[:, 1:, :]
  • 功能:Transformer本身是无位置信息的,通过加入位置嵌入(positional embedding)让模型感知patch的空间位置。
  • 细节
    • self.pos_embed 是预定义的位置嵌入参数,形状为 [1, L+1, D]L:patch总数,1:cls token的位置,Dembed_dim)。
    • self.pos_embed[:, 1:, :] 表示取所有patch的位置嵌入(排除cls token的位置),与patch特征 x 相加,输出形状仍为 [B, L, D]
4. CLS Token处理与拼接
  • CLS Token:是一个可学习的向量(self.cls_token),形状为 [1, 1, D],用于聚合整个序列的特征(类似句子分类中的「[CLS]」标记)。
  • 处理过程
    • cls_token = self.cls_token + self.pos_embed[:, :1, :]:给cls token加上对应的位置嵌入(self.pos_embed 的第0位)。
    • cls_tokens = cls_token.expand(B, -1, -1):将cls token扩展到整个批次,形状变为 [B, 1, D]
    • x = torch.cat((cls_tokens, x), dim=1):将cls token与patch特征序列拼接,输出形状为 [B, L+1, D]L+1:cls token + L个patch)。
5. 位置Dropout(self.pos_drop(x)
  • 功能:对拼接后的序列(cls token + patch特征 + 位置嵌入)应用dropout,防止模型过度依赖位置信息,增强泛化能力。
  • 对应组件self.pos_drop 是父类初始化的dropout层(通常为 nn.Dropout)。
6. Transformer块编码(for blk in self.blocks: x = blk(x)
  • 功能:通过多层Transformer块对序列特征进行深度编码,捕捉patch之间的长距离依赖关系。
  • 细节
    • self.blocks 是Transformer块的列表,数量由 depth 参数指定(如12层)。
    • 每个Transformer块包含「多头自注意力」和「MLP」两个核心模块,通过残差连接和层归一化增强训练稳定性。
    • 输入输出形状保持不变([B, L+1, D]),但特征经过深层语义编码。
7. 特征聚合(最终特征提取)

根据 global_pool 参数选择不同的聚合方式:

  • 方式1:全局平均池化(global_pool=True

    • x = x[:, 1:, :].mean(dim=1):对所有patch特征(排除cls token)在序列维度(dim=1)上求平均,得到形状为 [B, D] 的全局特征。
    • outcome = self.fc_norm(x):通过 fc_norm 归一化层(在 __init__ 中创建)进行归一化,输出最终特征。
  • 方式2:CLS Token(global_pool=False

    • x = self.norm(x):对整个序列(cls token + patch)进行归一化(使用父类的 self.norm 层)。
    • outcome = x[:, 0]:取cls token(序列的第0位)作为最终特征,形状为 [B, D]
输入输出总结
  • 输入:原始数据 x,形状 [B, C, H, W]
  • 输出:高层特征 outcome,形状 [B, D]D=embed_dim)。
核心作用

forward_features 实现了ViT的核心特征提取流程:从原始输入→分块嵌入→位置编码→Transformer编码→特征聚合。它将底层的视觉/频谱信息转换为可用于分类、检索等任务的高层语义特征,是模型性能的关键所在。

该方法的灵活性体现在支持两种特征聚合方式(cls token/全局池化),可根据任务需求选择更适合的方案(例如全局池化在某些场景下可能比cls token更鲁棒)。

2.3 1D随机掩码:random_masking

random_masking 方法是实现一维随机掩码机制的核心函数,主要用于自监督学习中(如掩码自编码器MAE)对输入序列进行随机掩盖,通过让模型重建被掩盖的部分来学习更鲁棒的特征表示。以下是详细说明:

函数定义
def random_masking(self, x, mask_ratio):"""Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape  # batch, length, dimlen_keep = int(L * (1 - mask_ratio))  # 保留的序列长度noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]# sort noise for each sampleids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is removeids_restore = torch.argsort(ids_shuffle, dim=1)  # 恢复原始顺序的索引# keep the first subsetids_keep = ids_shuffle[:, :len_keep]  # 要保留的元素索引x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))  # 保留的特征# generate the binary mask: 0 is keep, 1 is removemask = torch.ones([N, L], device=x.device)mask[:, :len_keep] = 0# unshuffle to get the binary maskmask = torch.gather(mask, dim=1, index=ids_restore)  # 恢复掩码到原始序列顺序return x_masked, mask, ids_restore
核心功能

对输入的序列特征进行随机掩码,具体包括:

  1. 随机选择一部分元素保留,其余元素被掩盖(不参与后续计算)。
  2. 生成掩码矩阵(标记哪些元素被保留/掩盖)。
  3. 生成恢复索引(用于后续将掩码后的序列还原到原始顺序)。

该方法通过「基于随机噪声排序」的方式实现随机选择,确保每个样本的掩码是独立且随机的。

参数与输入输出
  • 输入

    • x:输入序列特征,形状为 [N, L, D]N:批次大小,L:序列长度,D:特征维度)。
    • mask_ratio:掩码比例(被掩盖的元素占比,如 0.75 表示掩盖75%的元素)。
  • 输出

    • x_masked:掩码后的序列特征,形状为 [N, len_keep, D]len_keep 为保留的元素数量)。
    • mask:掩码矩阵,形状为 [N, L]0 表示保留,1 表示掩盖)。
    • ids_restore:恢复索引,形状为 [N, L](用于将掩码后的序列还原到原始顺序)。
逐步骤解析
1. 计算保留长度
N, L, D = x.shape  # 解析输入序列的维度
len_keep = int(L * (1 - mask_ratio))  # 计算需要保留的元素数量
  • 例如:若序列长度 L=100mask_ratio=0.75,则 len_keep=25(保留25%的元素)。
2. 生成随机噪声并排序
noise = torch.rand(N, L, device=x.device)  # 生成[N, L]的随机噪声(取值范围[0,1))
ids_shuffle = torch.argsort(noise, dim=1)  # 对噪声按行排序,得到打乱的索引
  • 作用:通过对随机噪声排序实现「随机选择」。噪声值越小的位置,在 ids_shuffle 中索引越靠前(优先被保留)。
  • 示例:假设 noise 某行为 [0.8, 0.2, 0.5],则 ids_shuffle[1, 2, 0](对应噪声从小到大的索引)。
3. 生成恢复索引
ids_restore = torch.argsort(ids_shuffle, dim=1)  # 对打乱的索引再排序,得到恢复原始顺序的索引
  • 作用:记录掩码后的元素如何映射回原始序列位置,用于后续重建任务(如MAE中还原被掩盖的元素)。
  • 示例:若 ids_shuffle = [1, 2, 0],则 ids_restore = [2, 0, 1](表示原始索引0在打乱后位于位置2,原始索引1位于位置0, etc.)。
4. 提取保留的特征
ids_keep = ids_shuffle[:, :len_keep]  # 取前len_keep个索引(噪声最小的元素)
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
  • ids_keep.unsqueeze(-1).repeat(1, 1, D):将 [N, len_keep] 的索引扩展为 [N, len_keep, D],与输入 x 的维度匹配。
  • torch.gather(x, dim=1, index=...):沿序列维度(dim=1)提取保留的元素,得到掩码后的特征 x_masked
5. 生成掩码矩阵
mask = torch.ones([N, L], device=x.device)  # 初始化全为1的掩码(1表示掩盖)
mask[:, :len_keep] = 0  # 前len_keep个位置设为0(0表示保留)
mask = torch.gather(mask, dim=1, index=ids_restore)  # 将掩码恢复到原始序列顺序
  • 先构造与 ids_shuffle 顺序一致的掩码(前 len_keep 个为0),再通过 ids_restore 映射回原始序列的顺序,最终 mask 中每个位置的值表示该位置是否被掩盖(1=掩盖,0=保留)。
核心特点
  1. 逐样本随机:每个样本的掩码是独立生成的,避免样本间的掩码相关性。
  2. 无偏选择:通过均匀分布的随机噪声实现无偏随机选择,确保每个元素被保留的概率相同。
  3. 可恢复性:通过 ids_restore 实现掩码后序列到原始序列的映射,为后续重建任务提供支持。
应用场景

该方法主要用于自监督预训练,例如:

  • 掩码自编码器(MAE):掩盖大部分输入序列,让模型学习从少量保留元素中重建原始序列,从而学习数据的内在结构。
  • 对比学习:通过不同掩码策略生成样本的不同视图,用于对比学习任务。

在代码中,当 mask_2d=False 时,forward_features_mask 会调用该方法对序列进行一维掩码处理。

2.4 2D随机掩码:random_masking_2d

random_masking_2d 方法是针对二维结构数据(如音频频谱图、视频帧序列等具有时间-空间或时间-频率维度的数据)设计的掩码机制。与一维掩码(random_masking)对序列进行无差别随机掩盖不同,它会分别对数据的两个维度(如时间维度 T 和频率维度 F)进行掩码,更贴合二维数据的内在结构。以下是详细说明:

函数定义
def random_masking_2d(self, x, mask_t_prob, mask_f_prob):"""2D: Spectrogram (masking t and f under mask_t_prob and mask_f_prob)Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape  # batch, length, dim(L = T*F,即二维结构展开后的序列长度)if self.use_custom_patch:# 自定义patch划分:根据数据集设置时间(T)和频率(F)维度大小T=101; F=12  # 例如AS(AudioSet)数据集的参数# 其他数据集可选参数:ESC数据集 T=50, F=12;SPC数据集 T=12, F=12else:# 默认patch划分T=64; F=8    # 例如默认AS数据集参数# 其他数据集可选参数:ESC T=32, F=8;SPC T=8, F=8# 第一步:时间维度(T)掩码x = x.reshape(N, T, F, D)  # 将序列重塑为二维结构 [N, T, F, D](时间×频率×特征)len_keep_T = int(T * (1 - mask_t_prob))  # 时间维度保留的长度noise = torch.rand(N, T, device=x.device)  # 生成时间维度的随机噪声 [N, T]ids_shuffle = torch.argsort(noise, dim=1)  # 对噪声排序,得到时间维度的打乱索引ids_keep = ids_shuffle[:, :len_keep_T]  # 保留前len_keep_T个时间索引# 构造索引并提取保留的时间维度特征index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D)  # 扩展索引维度以匹配xx = torch.gather(x, dim=1, index=index)  # 时间维度掩码后:[N, len_keep_T, F, D]# 第二步:频率维度(F)掩码x = x.permute(0, 2, 1, 3)  # 转置为 [N, F, len_keep_T, D](方便频率维度处理)len_keep_F = int(F * (1 - mask_f_prob))  # 频率维度保留的长度noise = torch.rand(N, F, device=x.device)  # 生成频率维度的随机噪声 [N, F]ids_shuffle = torch.argsort(noise, dim=1)  # 对噪声排序,得到频率维度的打乱索引ids_keep = ids_shuffle[:, :len_keep_F]  # 保留前len_keep_F个频率索引# 构造索引并提取保留的频率维度特征index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D)  # 扩展索引维度x_masked = torch.gather(x, dim=1, index=index)  # 频率维度掩码后:[N, len_keep_F, len_keep_T, D]# 恢复维度顺序并重塑为序列x_masked = x_masked.permute(0, 2, 1, 3)  # 转置回 [N, len_keep_T, len_keep_F, D]x_masked = x_masked.reshape(N, len_keep_F*len_keep_T, D)  # 展开为序列 [N, len_keep_T*len_keep_F, D]return x_masked, None, None  # 暂不返回mask和ids_restore(可扩展用于重建)
核心功能

针对二维结构数据(如频谱图的“时间-频率”维度),分两步进行掩码:

  1. 对时间维度(T)按 mask_t_prob 比例随机掩盖,保留部分时间片段。
  2. 对频率维度(F)按 mask_f_prob 比例随机掩盖,保留部分频率成分。
    最终将掩码后的二维结构重新展开为序列,用于后续模型输入。
参数与输入输出
  • 输入

    • x:输入序列特征,形状为 [N, L, D]N:批次大小,L=T*F:二维结构展开后的序列长度,D:特征维度)。
    • mask_t_prob:时间维度的掩码比例(如 0.5 表示掩盖50%的时间片段)。
    • mask_f_prob:频率维度的掩码比例(如 0.3 表示掩盖30%的频率成分)。
  • 输出

    • x_masked:掩码后的序列特征,形状为 [N, len_keep_T*len_keep_F, D]len_keep_T/len_keep_F 分别为时间/频率维度保留的长度)。
    • None, None:暂未实现掩码矩阵(mask)和恢复索引(ids_restore),可扩展用于后续重建任务。
逐步骤解析
1. 解析输入与维度设置
N, L, D = x.shape  # 解析输入维度:N=批次,L=序列长度(T*F),D=特征维度
# 根据use_custom_patch设置时间(T)和频率(F)维度的原始大小
if self.use_custom_patch:T=101; F=12  # 自定义patch划分(如AS数据集)
else:T=64; F=8    # 默认patch划分
  • 关键:L 必须等于 T*F(即输入序列是二维结构 [T, F] 展开的结果),否则重塑会出错。
  • TF 的值需与实际数据的时间-频率维度匹配(如频谱图的时间步数和频率 bins 数)。
2. 时间维度(T)掩码
x = x.reshape(N, T, F, D)  # 重塑为二维结构:[N, T, F, D](时间×频率×特征)
len_keep_T = int(T * (1 - mask_t_prob))  # 计算时间维度保留的长度
# 生成随机噪声并排序,选择保留的时间索引
noise = torch.rand(N, T, device=x.device)  # [N, T] 的随机噪声([0,1))
ids_shuffle = torch.argsort(noise, dim=1)  # 对噪声排序,得到时间维度的打乱索引
ids_keep = ids_shuffle[:, :len_keep_T]  # 保留前len_keep_T个时间索引(噪声最小的位置)
# 提取保留的时间维度特征
index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D)  # 扩展索引为 [N, len_keep_T, F, D]
x = torch.gather(x, dim=1, index=index)  # 沿时间维度(dim=1)提取,结果:[N, len_keep_T, F, D]
  • 目的:随机保留部分时间片段,掩盖其余时间维度的信息(如在音频频谱图中,掩盖某些时间段的频谱)。
3. 频率维度(F)掩码
x = x.permute(0, 2, 1, 3)  # 转置为 [N, F, len_keep_T, D](将频率维度放到dim=1,方便处理)
len_keep_F = int(F * (1 - mask_f_prob))  # 计算频率维度保留的长度
# 生成随机噪声并排序,选择保留的频率索引
noise = torch.rand(N, F, device=x.device)  # [N, F] 的随机噪声([0,1))
ids_shuffle = torch.argsort(noise, dim=1)  # 对噪声排序,得到频率维度的打乱索引
ids_keep = ids_shuffle[:, :len_keep_F]  # 保留前len_keep_F个频率索引(噪声最小的位置)
# 提取保留的频率维度特征
index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D)  # 扩展索引为 [N, len_keep_F, len_keep_T, D]
x_masked = torch.gather(x, dim=1, index=index)  # 沿频率维度(dim=1)提取,结果:[N, len_keep_F, len_keep_T, D]
  • 目的:在时间维度掩码的基础上,进一步随机保留部分频率成分(如在音频频谱图中,掩盖某些频率范围的信息)。
4. 重塑为序列输出
x_masked = x_masked.permute(0, 2, 1, 3)  # 转置回 [N, len_keep_T, len_keep_F, D](时间×频率×特征)
x_masked = x_masked.reshape(N, len_keep_F*len_keep_T, D)  # 展开为序列:[N, len_keep_T*len_keep_F, D]
  • 最终将二维结构重新展开为一维序列,以适应Transformer对序列输入的要求。
核心特点
  1. 维度感知掩码:区分时间和频率维度分别进行掩码,更贴合二维数据(如频谱图)的物理结构,保留了有意义的局部时空/频域相关性。
  2. 灵活配置:通过 use_custom_patch 支持不同数据集的维度参数(TF),适配多样化的二维数据。
  3. 分步掩码:先时间后频率的两步掩码策略,可独立控制两个维度的掩码强度(通过 mask_t_probmask_f_prob)。
应用场景

主要用于二维结构数据的自监督学习,例如:

  • 音频频谱图处理:对音频的梅尔频谱图(时间-频率维度)进行掩码,让模型学习从部分时频信息中重建完整频谱,提升音频分类/检索性能。
  • 视频帧序列:对视频的“时间-空间”维度进行掩码(如时间上掩盖部分帧,空间上掩盖部分区域),学习视频的时序和空间特征。

在代码中,当 mask_2d=True 时,forward_features_mask 会调用该方法对二维结构数据进行掩码处理。

2.5 带掩码的特征提取:forward_features_mask

forward_features_mask 方法是该 Vision Transformer 模型中用于带掩码的特征提取的核心逻辑,主要服务于自监督学习场景(如掩码自编码任务)。它在常规特征提取流程(forward_features)的基础上,加入了掩码操作,通过随机掩盖部分输入特征并让模型学习处理剩余特征,从而增强模型对数据本质结构的理解。以下是详细说明:

函数定义
def forward_features_mask(self, x, mask_t_prob, mask_f_prob):B = x.shape[0]  # 获取批次大小(batch size)x = self.patch_embed(x)  # 1. 将输入分块并嵌入到特征空间x = x + self.pos_embed[:, 1:, :]  # 2. 加入位置嵌入(排除cls token的位置)# 3. 根据配置选择1D或2D掩码机制if self.mask_2d:x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob)else:x, mask, ids_restore = self.random_masking(x, mask_t_prob)# 4. 处理cls token并与掩码后的patch特征拼接cls_token = self.cls_token + self.pos_embed[:, :1, :]  # cls token加位置嵌入cls_tokens = cls_token.expand(B, -1, -1)  # 扩展到整个批次x = torch.cat((cls_tokens, x), dim=1)  # 拼接cls token和掩码后的patch特征序列x = self.pos_drop(x)  # 5. 位置dropout(防止过拟合)# 6. 通过所有Transformer块进行特征编码for blk in self.blocks:x = blk(x)# 7. 特征聚合(根据配置选择cls token或全局池化)if self.global_pool:x = x[:, 1:, :].mean(dim=1)  # 全局平均池化(排除cls token)outcome = self.fc_norm(x)    # 全局池化后的归一化else:x = self.norm(x)  # 传统归一化(使用父类的norm层)outcome = x[:, 0]  # 取cls token作为最终特征return outcome  # 返回提取的高层特征
核心功能

该方法在常规特征提取流程中插入了掩码操作,具体流程为:输入数据→patch嵌入→位置编码→随机掩码→拼接cls token→Transformer编码→特征聚合。其核心目的是让模型在只有部分输入特征(未被掩码的部分)的情况下学习有效特征,从而提升模型的泛化能力和特征表示能力(尤其适用于自监督预训练)。

参数与输入输出
  • 输入

    • x:原始输入数据,形状通常为 [B, C, H, W](如图像、频谱图等)。
    • mask_t_prob:时间维度(或1D序列)的掩码比例(如 0.75 表示掩盖75%的时间/序列元素)。
    • mask_f_prob:频率维度的掩码比例(仅在 mask_2d=True 时有效,如 0.5 表示掩盖50%的频率元素)。
  • 输出

    • outcome:经过掩码处理和Transformer编码后的高层特征,形状为 [B, D]D=embed_dim,特征维度)。
逐步骤解析
1. 输入与批次大小
  • x 为原始输入数据(如 [B, 1, 1024, 128] 的音频频谱图),B = x.shape[0] 获取批次大小,用于后续扩展 cls_token
2. Patch嵌入(self.patch_embed(x)
  • forward_features 相同,将输入数据分割为patch并映射到高维特征空间。例如,输入 [B, 1, 1024, 128] 经过patch嵌入后可能变为 [B, 512, 768]512 为patch总数,768 为特征维度)。
3. 位置嵌入(x = x + self.pos_embed[:, 1:, :]
  • 为patch特征加入位置嵌入,让模型感知patch的空间/时间位置。输出形状仍为 [B, L, D]L 为patch总数)。
4. 随机掩码(核心差异点)
if self.mask_2d:x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob)
else:x, mask, ids_restore = self.random_masking(x, mask_t_prob)
  • 关键操作:根据 self.mask_2d 选择掩码方式,对patch特征进行随机掩盖:
    • mask_2d=True:调用 random_masking_2d,对二维结构(如时间-频率)的patch分别按 mask_t_prob(时间掩码比例)和 mask_f_prob(频率掩码比例)进行掩码。
    • mask_2d=False:调用 random_masking,对一维序列的patch按 mask_t_prob(整体掩码比例)进行掩码。
  • 结果x 变为掩码后的特征序列,长度缩短(仅保留未被掩盖的patch),形状为 [B, L_keep, D]L_keep 为保留的patch数量)。
5. CLS Token处理与拼接
  • forward_features 逻辑一致:为cls token加入位置嵌入,扩展到整个批次,并与掩码后的patch特征拼接,输出形状为 [B, L_keep+1, D]+1 为cls token)。
6. 位置Dropout(self.pos_drop(x)
  • 对拼接后的序列(cls token + 掩码后的patch特征)应用dropout,增强模型泛化能力。
7. Transformer块编码
  • 掩码后的序列通过所有Transformer块(self.blocks)进行深度编码,捕捉剩余patch之间的依赖关系。输入输出形状保持 [B, L_keep+1, D]
8. 特征聚合
  • forward_features 相同,根据 self.global_pool 选择聚合方式:
    • global_pool=True:对掩码后保留的patch特征(排除cls token)做全局平均池化,经 fc_norm 归一化后输出。
    • global_pool=False:对序列做归一化后,取cls token作为最终特征。
forward_features 的核心差异
对比项forward_featuresforward_features_mask
掩码操作无(使用全部patch特征)有(随机掩盖部分patch特征)
输入特征完整性完整的patch序列部分patch被掩盖的序列
主要用途有监督训练/推理(使用全部信息)自监督预训练(模拟信息缺失场景)
序列长度固定(L+1L为总patch数)可变(L_keep+1L_keep为保留数)
应用场景

该方法主要用于自监督预训练任务,例如:

  • 掩码自编码器(MAE):通过掩盖大部分输入patch,让模型学习从少量保留信息中重建原始数据,从而学习数据的内在结构。
  • 对比学习:通过不同掩码策略生成同一数据的不同视图,让模型学习视图间的一致性,提升特征判别能力。

在模型的 forward 方法中,当 mask_t_prob>0mask_f_prob>0 时,会自动调用该方法进行带掩码的特征提取。

2.6 模型前向传播:forward

forward 方法是模型的入口函数,负责协调整个模型的前向计算流程,根据输入参数决定是否启用掩码机制,并最终输出模型的预测结果。它是连接输入数据与最终输出的核心桥梁,以下是详细说明:

函数定义
def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0):if mask_t_prob > 0.0 or mask_f_prob > 0.0:x = self.forward_features_mask(x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)else:x = self.forward_features(x)x = self.head(x)  # 最终分类头处理return x
核心功能

该方法根据掩码概率参数(mask_t_probmask_f_prob)决定特征提取的路径:

  • 当需要掩码(掩码概率>0)时,调用带掩码的特征提取方法(forward_features_mask)。
  • 当不需要掩码(掩码概率=0)时,调用常规特征提取方法(forward_features)。
    最终通过分类头(self.head)输出预测结果,完成整个前向传播流程。
参数解析
  • 输入参数

    • x:原始输入数据,形状通常为 [B, C, H, W]B:批次大小,C:通道数,H/W:高度/宽度),例如图像、音频频谱图等。
    • v:未使用的参数(可能是预留接口,用于扩展多模态输入等场景)。
    • mask_t_prob:时间维度(或1D序列)的掩码比例(默认 0.0,表示不掩码)。当 >0 时,启用掩码机制。
    • mask_f_prob:频率维度的掩码比例(默认 0.0),仅在 mask_2d=True 时有效,>0 时启用频率维度掩码。
  • 输出

    • 模型的预测结果,形状通常为 [B, num_classes]num_classes 为分类任务的类别数)。
逐步骤解析
1. 决定特征提取路径
if mask_t_prob > 0.0 or mask_f_prob > 0.0:# 启用掩码:调用带掩码的特征提取x = self.forward_features_mask(x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)
else:# 不启用掩码:调用常规特征提取x = self.forward_features(x)
  • 逻辑:通过判断掩码概率是否大于0,选择对应的特征提取流程:
    • 当需要掩码(如自监督预训练时):使用 forward_features_mask,对输入数据进行随机掩码后再提取特征。
    • 当不需要掩码(如常规训练/推理时):使用 forward_features,直接对完整输入提取特征。
  • 特征提取结果:两种路径均输出形状为 [B, D] 的高层特征(Dembed_dim,特征维度)。
2. 分类头处理(self.head(x)
x = self.head(x)
  • 功能:将高层特征映射到任务输出空间(如分类任务的类别概率)。
  • 细节self.head 是在父类 VisionTransformer 中初始化的分类头,通常为线性层(nn.Linear),输入维度为 embed_dim,输出维度为任务的类别数(num_classes)。
  • 示例:若 embed_dim=768,分类任务有1000个类别,则 self.headnn.Linear(768, 1000),输出形状为 [B, 1000]
3. 返回预测结果
return x
  • 输出最终的预测结果(如分类概率分布),供后续计算损失(训练时)或直接使用(推理时)。
关键特性
  1. 双路径设计:通过掩码概率参数无缝切换「带掩码」和「无掩码」两种模式,兼顾自监督预训练(需要掩码)和常规任务(不需要掩码)的需求。
  2. 兼容性:保留了父类ViT的基本接口,同时扩展了掩码相关参数,不破坏原有使用习惯。
  3. 灵活性:支持独立控制时间和频率维度的掩码比例(mask_t_probmask_f_prob),适配不同的自监督训练策略。
应用场景
  • 自监督预训练:当设置 mask_t_prob>0mask_f_prob>0 时,模型进入掩码模式,用于训练模型从部分信息中学习数据结构(如MAE任务)。
  • 有监督训练/推理:当 mask_t_prob=0mask_f_prob=0 时,模型使用完整输入进行特征提取,用于常规的分类、回归等任务。
总结

forward 方法是模型的总调度器,通过简单的条件判断协调两种特征提取路径,最终输出预测结果。它的设计既满足了自监督学习对掩码机制的需求,又保持了常规任务的兼容性,体现了模型在不同训练阶段和应用场景下的灵活性。

3. 模型实例化函数

定义了不同规模的ViT模型(参数与原始ViT一致):

  • vit_small_patch16:小模型,patch大小16×16,嵌入维度384,12层Transformer,6个注意力头。
  • vit_base_patch16:基础模型,嵌入维度768,12层,12个注意力头。
  • vit_large_patch16:大模型,嵌入维度1024,24层,16个注意力头。
  • vit_huge_patch14:超大模型,patch大小14×14,嵌入维度1280,32层,16个注意力头。

4. 核心改进与应用场景

  • 灵活的特征聚合:支持cls token或全局平均池化,适应不同任务需求。
  • 二维掩码机制:针对时间-频率等二维结构数据(如音频、视频)优化,更符合数据的空间/时间相关性。
  • 自监督学习支持:通过随机掩码实现类似MAE(Masked Autoencoder)的预训练任务,提升模型特征学习能力。

该代码可用于图像分类、音频频谱分析等任务,尤其适合需要自监督预训练的场景。

audioMAE

现在介绍audioMAE,先给出完整的代码。

from functools import partial
from json import encoderimport torch
import torch.nn as nn#from timm.models.vision_transformer import PatchEmbed, Block
from timm.models.vision_transformer import Block
from util.pos_embed import get_2d_sincos_pos_embed, get_2d_sincos_pos_embed_flexible, get_1d_sincos_pos_embed_from_grid
from util.misc import concat_all_gather
from util.patch_embed import PatchEmbed_new, PatchEmbed_org
from timm.models.swin_transformer import SwinTransformerBlockclass MaskedAutoencoderViT(nn.Module):""" Masked Autoencoder with VisionTransformer backbone"""def __init__(self, img_size=224, patch_size=16, stride=10, in_chans=3,embed_dim=1024, depth=24, num_heads=16,decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, audio_exp=False, alpha=0.0, temperature=.2, mode=0, contextual_depth=8,use_custom_patch=False, split_pos=False, pos_trainable=False, use_nce=False, beta=4.0, decoder_mode=0,mask_t_prob=0.6, mask_f_prob=0.5, mask_2d=False,epoch=0, no_shift=False,):super().__init__()self.audio_exp=audio_expself.embed_dim = embed_dimself.decoder_embed_dim = decoder_embed_dim# --------------------------------------------------------------------------# MAE encoder specificsif use_custom_patch:print(f'Use custom patch_emb with patch size: {patch_size}, stride: {stride}')self.patch_embed = PatchEmbed_new(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, stride=stride)else:self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim)self.use_custom_patch = use_custom_patchnum_patches = self.patch_embed.num_patchesself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))#self.split_pos = split_pos # not usefulself.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable)  # fixed sin-cos embeddingself.encoder_depth = depthself.contextual_depth = contextual_depthself.blocks = nn.ModuleList([Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)for i in range(depth)])self.norm = norm_layer(embed_dim)# --------------------------------------------------------------------------# MAE decoder specificsself.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=pos_trainable)  # fixed sin-cos embeddingself.no_shift=no_shiftself.decoder_mode = decoder_modeif self.use_custom_patch: # overlapped patches as in AST. Similar performance yet compute heavywindow_size= (6,6)feat_size = (102,12)else:window_size= (4,4)feat_size = (64,8)                if self.decoder_mode == 1:decoder_modules = []for index in range(16):if self.no_shift:shift_size = (0,0)else:if (index % 2) == 0:shift_size = (0,0)else:shift_size = (2,0)#shift_size = tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size])decoder_modules.append(SwinTransformerBlock(dim=decoder_embed_dim,num_heads=16,feat_size=feat_size,window_size=window_size,shift_size=shift_size,mlp_ratio=mlp_ratio,drop=0.0,drop_attn=0.0,drop_path=0.0,extra_norm=False,sequential_attn=False,norm_layer=norm_layer, #nn.LayerNorm,))self.decoder_blocks = nn.ModuleList(decoder_modules)        else:# Transfomerself.decoder_blocks = nn.ModuleList([Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)for i in range(decoder_depth)])self.decoder_norm = norm_layer(decoder_embed_dim)self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch# --------------------------------------------------------------------------self.norm_pix_loss = norm_pix_lossself.patch_size=patch_sizeself.stride=stride# audio expsself.alpha = alphaself.T = temperatureself.mode = modeself.use_nce = use_nceself.beta = betaself.log_softmax=nn.LogSoftmax(dim=-1)self.mask_t_prob=mask_t_probself.mask_f_prob=mask_f_probself.mask_2d=mask_2dself.epoch = epochself.initialize_weights()def initialize_weights(self):# initialization# initialize (and freeze) pos_embed by sin-cos embeddingif self.audio_exp:pos_embed = get_2d_sincos_pos_embed_flexible(self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True)    else:pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))if self.audio_exp:   decoder_pos_embed = get_2d_sincos_pos_embed_flexible(self.decoder_pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True)else:decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))# initialize patch_embed like nn.Linear (instead of nn.Conv2d)w = self.patch_embed.proj.weight.datatorch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)torch.nn.init.normal_(self.cls_token, std=.02)torch.nn.init.normal_(self.mask_token, std=.02)# initialize nn.Linear and nn.LayerNormself.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):# we use xavier_uniform following official JAX ViT:torch.nn.init.xavier_uniform_(m.weight)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)def patchify(self, imgs):"""imgs: (N, 3, H, W)x: (N, L, patch_size**2 *3)L = (H/p)*(W/p)"""p = self.patch_embed.patch_size[0]#assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0if self.audio_exp:if self.use_custom_patch: # overlapped patchh,w = self.patch_embed.patch_hw# todo: fixed h/w patch size and stride size. Make hw custom in the futurex = imgs.unfold(2, self.patch_size, self.stride).unfold(3, self.patch_size, self.stride) # n,1,H,W -> n,1,h,w,p,px = x.reshape(shape=(imgs.shape[0], h*w, p**2 * 1))#x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))#x = torch.einsum('nchpwq->nhwpqc', x)#x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))else:h = imgs.shape[2] // pw = imgs.shape[3] // p#h,w = self.patch_embed.patch_hwx = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))x = torch.einsum('nchpwq->nhwpqc', x)x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))else:h = w = imgs.shape[2] // px = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))x = torch.einsum('nchpwq->nhwpqc', x)x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))return xdef unpatchify(self, x):"""x: (N, L, patch_size**2 *3)specs: (N, 1, H, W)"""p = self.patch_embed.patch_size[0]    h = 1024//pw = 128//px = x.reshape(shape=(x.shape[0], h, w, p, p, 1))x = torch.einsum('nhwpqc->nchpwq', x)specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p))return specsdef random_masking(self, x, mask_ratio):"""Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape  # batch, length, dimlen_keep = int(L * (1 - mask_ratio))noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]# sort noise for each sampleids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is removeids_restore = torch.argsort(ids_shuffle, dim=1)# keep the first subsetids_keep = ids_shuffle[:, :len_keep]x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))# generate the binary mask: 0 is keep, 1 is removemask = torch.ones([N, L], device=x.device)mask[:, :len_keep] = 0# unshuffle to get the binary maskmask = torch.gather(mask, dim=1, index=ids_restore)return x_masked, mask, ids_restoredef random_masking_2d(self, x, mask_t_prob, mask_f_prob):"""2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob)Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape  # batch, length, dimif self.use_custom_patch: # overlapped patchT=101F=12else:            T=64F=8#x = x.reshape(N, T, F, D)len_keep_t = int(T * (1 - mask_t_prob))len_keep_f = int(F * (1 - mask_f_prob))# noise for mask in timenoise_t = torch.rand(N, T, device=x.device)  # noise in [0, 1]# sort noise for each sample aling timeids_shuffle_t = torch.argsort(noise_t, dim=1) # ascend: small is keep, large is removeids_restore_t = torch.argsort(ids_shuffle_t, dim=1) ids_keep_t = ids_shuffle_t[:,:len_keep_t]# noise mask in freqnoise_f = torch.rand(N, F, device=x.device)  # noise in [0, 1]ids_shuffle_f = torch.argsort(noise_f, dim=1) # ascend: small is keep, large is removeids_restore_f = torch.argsort(ids_shuffle_f, dim=1) ids_keep_f = ids_shuffle_f[:,:len_keep_f] ## generate the binary mask: 0 is keep, 1 is remove# mask in freqmask_f = torch.ones(N, F, device=x.device)mask_f[:,:len_keep_f] = 0mask_f = torch.gather(mask_f, dim=1, index=ids_restore_f).unsqueeze(1).repeat(1,T,1) # N,T,F# mask in timemask_t = torch.ones(N, T, device=x.device)mask_t[:,:len_keep_t] = 0mask_t = torch.gather(mask_t, dim=1, index=ids_restore_t).unsqueeze(1).repeat(1,F,1).permute(0,2,1) # N,T,Fmask = 1-(1-mask_t)*(1-mask_f) # N, T, F# get masked xid2res=torch.Tensor(list(range(N*T*F))).reshape(N,T,F).to(x.device)id2res = id2res + 999*mask # add a large value for masked elementsid2res2 = torch.argsort(id2res.flatten(start_dim=1))ids_keep=id2res2.flatten(start_dim=1)[:,:len_keep_f*len_keep_t]x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))ids_restore = torch.argsort(id2res2.flatten(start_dim=1))mask = mask.flatten(start_dim=1)return x_masked, mask, ids_restoredef forward_encoder(self, x, mask_ratio, mask_2d=False):# embed patchesx = self.patch_embed(x)# add pos embed w/o cls tokenx = x + self.pos_embed[:, 1:, :]# masking: length -> length * mask_ratioif mask_2d:x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob)else:x, mask, ids_restore = self.random_masking(x, mask_ratio)# append cls tokencls_token = self.cls_token + self.pos_embed[:, :1, :]cls_tokens = cls_token.expand(x.shape[0], -1, -1)x = torch.cat((cls_tokens, x), dim=1)# apply Transformer blocksfor blk in self.blocks:x = blk(x)x = self.norm(x)#emb = self.encoder_emb(x)return x, mask, ids_restore, Nonedef forward_encoder_no_mask(self, x):# embed patchesx = self.patch_embed(x)# add pos embed w/o cls tokenx = x + self.pos_embed[:, 1:, :]# masking: length -> length * mask_ratio#x, mask, ids_restore = self.random_masking(x, mask_ratio)# append cls tokencls_token = self.cls_token + self.pos_embed[:, :1, :]cls_tokens = cls_token.expand(x.shape[0], -1, -1)x = torch.cat((cls_tokens, x), dim=1)# apply Transformer blockscontextual_embs=[]for n, blk in enumerate(self.blocks):x = blk(x)if n > self.contextual_depth:contextual_embs.append(self.norm(x))#x = self.norm(x)contextual_emb = torch.stack(contextual_embs,dim=0).mean(dim=0)return contextual_embdef forward_decoder(self, x, ids_restore):# embed tokensx = self.decoder_embed(x)# append mask tokens to sequencemask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls tokenx_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshufflex = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token# add pos embedx = x + self.decoder_pos_embedif self.decoder_mode != 0:B,L,D=x.shapex = x[:,1:,:]if self.use_custom_patch:x = x.reshape(B,101,12,D)x = torch.cat([x,x[:,-1,:].unsqueeze(1)],dim=1) # hackx = x.reshape(B,1224,D)if self.decoder_mode > 3: # mvitx = self.decoder_blocks(x)else:# apply Transformer blocksfor blk in self.decoder_blocks:x = blk(x)x = self.decoder_norm(x)# predictor projectionpred = self.decoder_pred(x)# remove cls tokenif self.decoder_mode != 0:if self.use_custom_patch:pred = pred.reshape(B,102,12,256)pred = pred[:,:101,:,:]pred = pred.reshape(B,1212,256)else:pred = predelse:pred = pred[:, 1:, :]return pred, None, None #emb, emb_pixeldef forward_loss(self, imgs, pred, mask, norm_pix_loss=False):"""imgs: [N, 3, H, W]pred: [N, L, p*p*3]mask: [N, L], 0 is keep, 1 is remove, """target = self.patchify(imgs)if norm_pix_loss:mean = target.mean(dim=-1, keepdim=True)var = target.var(dim=-1, keepdim=True)target = (target - mean) / (var + 1.e-6)**.5loss = (pred - target) ** 2loss = loss.mean(dim=-1)  # [N, L], mean loss per patchloss = (loss * mask).sum() / mask.sum()  # mean loss on removed patchesreturn loss      def forward(self, imgs, mask_ratio=0.8):emb_enc, mask, ids_restore, _ = self.forward_encoder(imgs, mask_ratio, mask_2d=self.mask_2d)pred, _, _ = self.forward_decoder(emb_enc, ids_restore)  # [N, L, p*p*3]loss_recon = self.forward_loss(imgs, pred, mask, norm_pix_loss=self.norm_pix_loss)loss_contrastive = torch.FloatTensor([0.0]).cuda()return loss_recon, pred, mask, loss_contrastivedef mae_vit_small_patch16_dec512d8b(**kwargs):model = MaskedAutoencoderViT(patch_size=16, embed_dim=384, depth=12, num_heads=6,decoder_embed_dim=512, decoder_num_heads=16,mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return modeldef mae_vit_base_patch16_dec512d8b(**kwargs):model = MaskedAutoencoderViT(patch_size=16, embed_dim=768, depth=12, num_heads=12,decoder_embed_dim=512, decoder_num_heads=16,mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return modeldef mae_vit_large_patch16_dec512d8b(**kwargs):model = MaskedAutoencoderViT(patch_size=16, embed_dim=1024, depth=24, num_heads=16,decoder_embed_dim=512, decoder_num_heads=16,mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return modeldef mae_vit_huge_patch14_dec512d8b(**kwargs):model = MaskedAutoencoderViT(patch_size=14, embed_dim=1280, depth=32, num_heads=16,decoder_embed_dim=512, decoder_num_heads=16,mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return model# set recommended archs
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b # decoder: 512 dim, 8 blocks

这段代码实现了一个基于Vision Transformer(ViT)的掩码自编码器(Masked Autoencoder, MAE),核心用于自监督学习——通过随机掩盖输入数据的部分区域,让模型学习从剩余区域重建完整数据,从而提取通用的特征表示。该模型特别适配了音频频谱图(如梅尔频谱)的处理(通过audio_exp、2D掩码等参数),同时保留了对图像数据的兼容性。

一、整体框架与核心设计

MAE的核心逻辑是“编码-解码”结构:

  1. 编码器(Encoder):对输入数据分块(Patch),随机掩盖部分Patch后,通过Transformer提取特征。
  2. 解码器(Decoder):接收编码器输出的“未掩盖Patch特征”,拼接“掩码Token”,通过Transformer(或Swin Transformer)重建被掩盖的Patch。
  3. 损失计算:仅对被掩盖的Patch计算重建损失,迫使模型学习数据的内在结构。

该代码在标准MAE基础上扩展了音频适配(如重叠Patch、2D掩码)、灵活解码器(支持ViT/Swin解码器)、可配置位置嵌入等功能,下面分模块详细解析。

二、核心类:MaskedAutoencoderViT

2.1 构造函数 __init__:初始化编码器/解码器组件

__init__ 是模型的“骨架搭建”部分,定义了编码器、解码器的核心模块及超参数,参数多达20+,需按功能分组理解:

1. 基础配置参数
参数功能说明
img_size输入数据尺寸(如音频频谱图 1024×128、图像 224×224
patch_sizePatch(数据块)的大小(如 16 表示 16×16 的Patch)
stridePatch划分的步长(仅自定义Patch时生效,用于生成重叠Patch
in_chans输入通道数(图像3通道,音频频谱图1通道)
embed_dim编码器特征维度(如ViT-Base为768)
depth编码器Transformer块数量
num_heads编码器多头注意力头数
decoder_embed_dim解码器特征维度(通常小于编码器,如512,降低计算量)
decoder_depth解码器Transformer块数量
decoder_num_heads解码器多头注意力头数
mlp_ratioTransformer块中MLP层的扩张比例(如4表示隐藏层维度是输入的4倍)
norm_layer归一化层类型(默认nn.LayerNorm
norm_pix_loss是否对像素进行归一化后计算损失(稳定训练)
2. 音频适配参数
参数功能说明
audio_exp是否为音频实验(True时处理1通道频谱图,False处理3通道图像)
use_custom_patch是否使用自定义Patch嵌入(PatchEmbed_new),支持重叠Patch(适配音频)
mask_t_prob/mask_f_prob2D掩码中“时间维度”和“频率维度”的掩码比例(仅音频频谱图生效)
mask_2d是否启用2D掩码(区分时间/频率维度,而非1D无差别掩码)
3. 训练与结构适配参数
参数功能说明
pos_trainable位置嵌入(Positional Embedding)是否可训练(默认固定为sin-cos嵌入)
decoder_mode解码器类型(0=ViT解码器,1=Swin Transformer解码器,适配二维结构)
no_shiftSwin解码器中是否禁用“移位窗口”(避免边界效应)
use_nce是否启用对比损失(预留接口,当前未实现)
alpha/temperature对比损失相关超参数(预留)
4. 核心组件初始化

MaskedAutoencoderViT 类的构造函数(__init__)中,核心组件的初始化围绕编码器(Encoder)解码器(Decoder) 两大模块展开,每个模块包含多个关键组件,共同支撑掩码自编码器的“编码-解码-重建”流程。以下是核心组件的详细初始化描述:

一、编码器(Encoder)组件初始化

编码器的作用是:将输入数据分割为Patch、添加位置信息、随机掩码部分Patch后,通过Transformer提取高层特征。其核心组件包括:

1. Patch嵌入层(self.patch_embed
  • 功能:将输入的二维数据(图像或音频频谱图)分割为固定大小的Patch,并通过线性映射将每个Patch转换为高维特征(维度为embed_dim)。
  • 初始化逻辑
    if use_custom_patch:self.patch_embed = PatchEmbed_new(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, stride=stride)
    else:self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim)
    
    • PatchEmbed_org:普通Patch嵌入(无重叠),适用于图像等结构化数据。通过卷积或线性层将[H, W]的输入按patch_size分割为(H/patch_size)×(W/patch_size)个Patch,映射到embed_dim维度。
    • PatchEmbed_new:自定义Patch嵌入(支持重叠),适用于音频频谱图等需要保留时序/频域连续性的数据。通过stride参数控制滑动步长(如patch_size=16stride=10),生成重叠Patch,增强局部相关性。
  • 关键参数img_size(输入尺寸)、patch_size(Patch大小)、stride(滑动步长,仅自定义模式有效)、in_chans(输入通道数)、embed_dim(输出特征维度)。
2. 分类标记(self.cls_token
  • 功能:一个可学习的向量,用于聚合编码器的全局特征(类似ViT中的[CLS]标记),辅助解码器重建全局信息。
  • 初始化逻辑
    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    
    • 形状为[1, 1, embed_dim]( batch维度=1,序列长度=1,特征维度=embed_dim)。
    • 初始化为全零向量,后续通过训练学习(在initialize_weights中用正态分布细化初始化:torch.nn.init.normal_(self.cls_token, std=.02))。
3. 编码器位置嵌入(self.pos_embed
  • 功能:为Patch添加位置信息(Transformer本身无位置感知能力),使模型理解Patch的空间/时序位置关系。
  • 初始化逻辑
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable)
    
    • 形状[1, num_patches + 1, embed_dim],其中num_patches是总Patch数(由self.patch_embed.num_patches获取),+1cls_token预留位置。
    • 可训练性:由pos_trainable控制,True则位置嵌入可通过训练更新;False则使用固定的sin-cos嵌入(在initialize_weights中初始化)。
    • 位置嵌入生成:音频数据用get_2d_sincos_pos_embed_flexible,图像用get_2d_sincos_pos_embed,确保与Patch的二维结构匹配。
4. Transformer编码器块(self.blocks
  • 功能:通过多层Transformer块对Patch特征进行深度编码,捕捉Patch间的长距离依赖关系。
  • 初始化逻辑
    self.blocks = nn.ModuleList([Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)for i in range(depth)
    ])
    
    • 组成nn.ModuleList包含depthBlock(来自timm.models.vision_transformer.Block),每个Block由“多头自注意力(Multi-Head Attention)”和“MLP”组成,通过残差连接和层归一化增强训练稳定性。
    • 关键参数embed_dim(输入特征维度)、num_heads(注意力头数)、mlp_ratio(MLP层扩张比例,如4表示隐藏层维度是输入的4倍)、norm_layer(归一化层类型)。
5. 编码器归一化层(self.norm
  • 功能:对编码器最后一层的输出进行归一化,稳定特征分布。
  • 初始化逻辑
    self.norm = norm_layer(embed_dim)
    
    • 使用norm_layer(默认nn.LayerNorm),输入维度为embed_dim,在_init_weights中被初始化为权重=1、偏置=0。
二、解码器(Decoder)组件初始化

解码器的作用是:接收编码器输出的“未掩码Patch特征”,拼接“掩码Token”后重建被掩码的Patch。其核心组件包括:

1. 解码器嵌入层(self.decoder_embed
  • 功能:将编码器输出的特征(embed_dim维度)线性投影到解码器的特征维度(decoder_embed_dim),匹配解码器输入要求。
  • 初始化逻辑
    self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
    
    • 线性层(nn.Linear),输入维度embed_dim,输出维度decoder_embed_dim,带偏置。
    • _init_weights中用Xavier均匀初始化权重,偏置初始化为0。
2. 掩码Token(self.mask_token
  • 功能:一个可学习的向量,用于填充被掩码的Patch位置(编码器未处理这些Patch,解码器需通过该Token重建)。
  • 初始化逻辑
    self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
    
    • 形状为[1, 1, decoder_embed_dim](与单个Patch的特征维度匹配)。
    • 初始化为全零向量,训练中学习最优值(在initialize_weights中用正态分布细化:torch.nn.init.normal_(self.mask_token, std=.02))。
3. 解码器位置嵌入(self.decoder_pos_embed
  • 功能:为解码器的Patch和cls_token添加位置信息,确保重建时Patch的位置对齐。
  • 初始化逻辑
    self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=pos_trainable)
    
    • 形状[1, num_patches + 1, decoder_embed_dim],与编码器位置嵌入结构一致,但维度适配解码器(decoder_embed_dim)。
    • 初始化方式:同编码器位置嵌入(sin-cos嵌入或可训练),在initialize_weights中通过get_2d_sincos_pos_embedget_2d_sincos_pos_embed_flexible生成。
4. 解码器Transformer块(self.decoder_blocks
  • 功能:对解码器输入(未掩码Patch特征+掩码Token)进行编码,学习重建被掩码Patch的特征。
  • 初始化逻辑:支持两种类型的解码器(由decoder_mode控制):
    if self.decoder_mode == 1:# Swin Transformer解码器(适配二维结构)decoder_modules = [SwinTransformerBlock(dim=decoder_embed_dim, num_heads=16, feat_size=feat_size,window_size=window_size, shift_size=shift_size, mlp_ratio=mlp_ratio,norm_layer=norm_layer) for index in range(16)]self.decoder_blocks = nn.ModuleList(decoder_modules)
    else:# 普通ViT解码器self.decoder_blocks = nn.ModuleList([Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)for _ in range(decoder_depth)])
    
    • 普通ViT解码器(decoder_mode≠1decoder_depth个标准Block,参数为decoder_embed_dim(特征维度)、decoder_num_heads(注意力头数)等,与编码器结构一致。
    • Swin Transformer解码器(decoder_mode=1:16个SwinTransformerBlock(来自timm.models.swin_transformer),适配二维数据(如音频频谱图的时间-频率维度):
      • window_size:注意力窗口大小(如(4,4)),控制局部注意力范围。
      • shift_size:移位窗口步长(偶数层不移位,奇数层移位2),避免重复计算并增强全局感受野。
      • feat_size:特征图尺寸(如(102,12)),匹配输入Patch的二维结构。
5. 解码器归一化层(self.decoder_norm
  • 功能:对解码器最后一层的输出进行归一化,稳定重建特征的分布。
  • 初始化逻辑
    self.decoder_norm = norm_layer(decoder_embed_dim)
    
    • 使用norm_layer,输入维度为decoder_embed_dim,初始化方式同编码器归一化层。
6. 解码器预测层(self.decoder_pred
  • 功能:将解码器输出的特征(decoder_embed_dim维度)线性投影到“单个Patch的原始像素维度”,实现Patch重建。
  • 初始化逻辑
    self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)
    
    • 输出维度为patch_size² × in_chans(单个Patch的像素总数,如16×16×1=256 for 音频单通道)。
    • _init_weights中用Xavier均匀初始化权重,偏置初始化为0。
三、核心组件的协同关系

编码器与解码器的组件通过“特征传递”和“位置对齐”协同工作:

  1. Patch嵌入与位置嵌入patch_embed输出的Patch特征与pos_embed相加,赋予位置信息。
  2. 掩码与Token拼接:编码器掩码后保留的Patch特征与cls_token拼接,输入blocks编码;解码器通过mask_token填充被掩码位置,与编码器输出对齐。
  3. 特征维度匹配decoder_embed将编码器特征投影到decoder_embed_dim,确保与解码器块输入维度一致。
  4. 重建映射decoder_pred将解码器输出映射到Patch像素维度,最终通过unpatchify还原为原始数据形状。
总结

构造函数通过初始化编码器的“Patch嵌入-位置编码-Transformer编码”组件和解码器的“特征投影-掩码填充-Transformer解码-重建预测”组件,搭建了MAE的完整架构。这些组件的设计充分考虑了音频/图像的多模态适配(重叠Patch、2D掩码)、特征维度匹配(embed_dimdecoder_embed_dim)和位置信息一致性(编码器/解码器位置嵌入),为自监督学习中的“掩码-重建”任务提供了基础。

2.2 权重初始化:initialize_weights

initialize_weights 方法是 MaskedAutoencoderViT 模型中权重初始化的核心逻辑,负责为模型所有可学习参数(如位置嵌入、Patch嵌入、Token向量、线性层等)设置合理的初始值。合理的权重初始化是模型训练稳定收敛的前提,尤其对于Transformer这类深度模型,不当的初始化可能导致梯度消失/爆炸或训练停滞。以下是该方法的详细解析:

方法定义与整体作用
def initialize_weights(self):# 初始化(并冻结)位置嵌入为sin-cos嵌入if self.audio_exp:pos_embed = get_2d_sincos_pos_embed_flexible(self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True)    else:pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))# 初始化解码器位置嵌入if self.audio_exp:   decoder_pos_embed = get_2d_sincos_pos_embed_flexible(self.decoder_pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True)else:decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))# 初始化Patch嵌入层(类似nn.Linear而非nn.Conv2d)w = self.patch_embed.proj.weight.datatorch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))# 初始化cls_token和mask_token(用timm的trunc_normal_等效实现)torch.nn.init.normal_(self.cls_token, std=.02)torch.nn.init.normal_(self.mask_token, std=.02)# 初始化所有nn.Linear和nn.LayerNorm层self.apply(self._init_weights)

核心作用:为模型的关键组件(位置嵌入、Patch嵌入、可学习Token、线性层、归一化层)设置初始权重,确保训练开始时参数分布合理,梯度流动稳定。

逐步骤解析
1. 编码器位置嵌入(self.pos_embed)的初始化

位置嵌入(Positional Embedding)是Transformer理解序列位置关系的核心,MAE中默认使用固定的正弦余弦嵌入(而非随机初始化),原因是:

  • 正弦余弦嵌入具有天然的位置连续性(距离近的位置嵌入相似),更适合捕捉空间/时序关系;
  • 固定嵌入可减少训练参数,避免过拟合,尤其在自监督预训练阶段。

初始化逻辑

if self.audio_exp:# 音频场景:使用灵活的2D正弦余弦嵌入(适配非正方形Patch)pos_embed = get_2d_sincos_pos_embed_flexible(embed_dim=self.pos_embed.shape[-1],  # 嵌入维度(如768)grid_size=self.patch_embed.patch_hw,  # Patch的二维尺寸(如T=101, F=12)cls_token=True  # 预留cls_token的位置(嵌入长度+1))
else:# 图像场景:使用标准2D正弦余弦嵌入(正方形Patch)pos_embed = get_2d_sincos_pos_embed(embed_dim=self.pos_embed.shape[-1],grid_size=int(self.patch_embed.num_patches** .5),  # 正方形网格尺寸(如14=224/16)cls_token=True)
# 将生成的numpy数组转换为Tensor,复制到pos_embed参数中
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
  • get_2d_sincos_pos_embed:生成正方形网格的位置嵌入,适用于图像(如14×14的Patch网格)。
  • get_2d_sincos_pos_embed_flexible:生成非正方形网格的位置嵌入,适用于音频频谱图(如101×12的时间-频率Patch网格)。
  • 最终嵌入形状为 [1, num_patches+1, embed_dim](+1为cls_token预留位置),通过data.copy_赋值,若pos_trainable=False(默认),则后续训练中不更新。
2. 解码器位置嵌入(self.decoder_pos_embed)的初始化

解码器位置嵌入的作用与编码器类似,用于对齐解码器中Patch的位置关系,确保重建时空间/时序一致性。

初始化逻辑:与编码器位置嵌入完全对称,仅嵌入维度为decoder_embed_dim(解码器特征维度),代码逻辑相同:

if self.audio_exp:   decoder_pos_embed = get_2d_sincos_pos_embed_flexible(...)
else:decoder_pos_embed = get_2d_sincos_pos_embed(...)
self.decoder_pos_embed.data.copy_(...)
3. Patch嵌入层(self.patch_embed.proj)的初始化

Patch嵌入层(proj)负责将输入数据的Patch映射到高维特征空间(如patch_embed.proj是卷积层或线性层),其初始化需确保映射前后的信号方差一致。

初始化逻辑

w = self.patch_embed.proj.weight.data  # 获取卷积/线性层的权重
# 将权重重塑为[输出维度, 输入维度](忽略空间维度),用Xavier均匀初始化
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
  • Xavier均匀初始化:适用于线性映射(如Patch嵌入本质是线性投影),通过使权重分布的方差与输入/输出维度适配,确保前向传播和反向传播中信号的方差一致,避免梯度消失/爆炸。
  • 无论proj是卷积层(nn.Conv2d)还是线性层(nn.Linear),均视为“输入Patch的像素→输出特征”的线性映射,因此重塑后按线性层方式初始化。
4. 可学习Token(cls_tokenmask_token)的初始化

cls_token(分类Token)和mask_token(掩码Token)是可学习的向量,需初始化为较小的随机值,避免初始时对特征产生过大影响。

初始化逻辑

# 用标准差0.02的正态分布初始化(等效于timm的trunc_normal_,截断值较大时近似正态分布)
torch.nn.init.normal_(self.cls_token, std=.02)
torch.nn.init.normal_(self.mask_token, std=.02)
  • cls_token用于聚合编码器全局特征,mask_token用于填充被掩码的Patch位置,初始值过小会导致初始影响弱,过大则可能主导特征,0.02的标准差是视觉Transformer中常用的经验值。
5. 线性层与归一化层的初始化(self._init_weights

通过self.apply(self._init_weights)对模型中所有子模块递归应用初始化,重点处理nn.Linearnn.LayerNorm

def _init_weights(self, m):if isinstance(m, nn.Linear):# 线性层:Xavier均匀初始化权重,偏置设为0torch.nn.init.xavier_uniform_(m.weight)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):# 归一化层:权重设为1,偏置设为0(初始不改变输入分布)nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)
  • 线性层(nn.Linear:如decoder_embed(编码器→解码器投影)、decoder_pred(解码器→Patch重建)等,Xavier初始化确保特征映射的稳定性,偏置初始化为0避免引入额外偏移。
  • 归一化层(nn.LayerNorm:如self.norm(编码器输出归一化)、self.decoder_norm(解码器输出归一化)等,初始化为“权重=1,偏置=0”,确保初始时不改变输入特征的分布(仅在训练中学习调整)。
关键设计原则
  1. 位置嵌入的固定性:默认使用sin-cos嵌入而非可学习嵌入,利用其天然的位置连续性,尤其适合音频/图像的空间/时序结构。
  2. 线性映射的一致性:Patch嵌入和线性层均使用Xavier初始化,确保特征在映射过程中方差稳定,避免梯度问题。
  3. 可学习Token的弱初始化cls_tokenmask_token初始为小随机值,让模型在训练中自主学习最优表示,避免初始主导特征。
  4. 归一化层的中性初始化:初始不改变输入分布,确保训练初期特征的自然演化。
总结

initialize_weights 方法通过针对性的初始化策略,为MAE的各核心组件(位置嵌入、Patch嵌入、可学习Token、线性层、归一化层)设置了合理的初始权重。这些策略兼顾了Transformer的结构特性(对位置敏感、依赖线性映射)和自监督学习的需求(稳定训练、捕捉数据内在结构),为模型后续的掩码重建任务奠定了基础。

2.3 Patch处理:patchifyunpatchify

patchifyunpatchify是掩码自编码器(MAE)中连接“原始输入数据”与“模型处理的Patch序列”的核心方法:

  • patchify:将原始二维数据(图像或音频频谱图)分割为固定大小的Patch,转换为模型可处理的序列格式。
  • unpatchify:将模型输出的Patch序列还原为原始数据形状,用于计算重建损失(对比原始数据与重建结果)。

这两个方法是MAE“掩码-重建”逻辑的基础,确保输入数据能被模型处理,且输出能被还原为原始格式进行损失计算。以下是详细解析:

一、patchify:将原始数据分割为Patch序列
功能

将形状为 [N, C, H, W] 的原始数据(N:批次大小,C:通道数,H/W:高度/宽度)分割为 L 个Patch(L = h × wh/w 为Patch的行数/列数),输出形状为 [N, L, p²×C] 的序列(p 为Patch大小,p²×C 为单个Patch的像素/特征维度)。

方法定义
def patchify(self, imgs):"""imgs: (N, C, H, W)  # 原始输入数据x: (N, L, patch_size**2 * C)  # 输出的Patch序列,L = h*w"""p = self.patch_embed.patch_size[0]  # Patch大小(假设正方形Patch)if self.audio_exp:  # 处理音频频谱图(1通道)if self.use_custom_patch:  # 重叠Patch(自定义划分)h, w = self.patch_embed.patch_hw  # Patch的行数和列数(如101×12)# 用unfold生成重叠Patch:在H和W维度按stride滑动窗口x = imgs.unfold(2, self.patch_size, self.stride).unfold(3, self.patch_size, self.stride)# 形状转换:[N, 1, h, w, p, p] → [N, h*w, p²×1]x = x.reshape(shape=(imgs.shape[0], h*w, p**2 * 1))else:  # 非重叠Patch(默认划分)h = imgs.shape[2] // p  # 行数 = 高度//Patch大小w = imgs.shape[3] // p  # 列数 = 宽度//Patch大小# 形状转换:[N, 1, H, W] → [N, 1, h, p, w, p] → 调整维度顺序 → [N, h*w, p²×1]x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))x = torch.einsum('nchpwq->nhwpqc', x)  # 调整维度顺序,将Patch内像素放在最后x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))else:  # 处理图像(3通道)h = w = imgs.shape[2] // p  # 图像通常为正方形,h=w# 形状转换:[N, 3, H, W] → [N, 3, h, p, w, p] → 调整维度顺序 → [N, h*w, p²×3]x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))x = torch.einsum('nchpwq->nhwpqc', x)x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))return x
关键逻辑与分情况解析

patchify的处理逻辑根据数据类型(音频/图像)和Patch划分方式(重叠/非重叠)有所不同,核心是将原始数据的空间维度(H, W)转换为“Patch数量×单个Patch特征”的序列维度。

1. 音频频谱图处理(self.audio_exp=True

音频频谱图通常为单通道(C=1),且具有时间(H)和频率(W)的二维结构,需保留时序/频域连续性,因此支持重叠Patch

  • 重叠Patch(self.use_custom_patch=True
    适用于需要保留局部连续性的场景(如音频时序特征),通过torch.Tensor.unfold实现滑动窗口分割:

    • unfold(dim, size, step):在指定维度(dim=2为H,dim=3为W)上,以size=self.patch_size为窗口大小,step=self.stride为滑动步长,生成重叠窗口。
    • 例如:输入频谱图[N, 1, 1024, 128]patch_size=16stride=10,则H维度生成(1024-16)/10 + 1 = 101个窗口,W维度生成(128-16)/10 + 1 = 12个窗口,最终得到h=101w=12个Patch,形状为[N, 1, 101, 12, 16, 16],reshape后为[N, 101×12, 16²×1] = [N, 1212, 256]
  • 非重叠Patch(self.use_custom_patch=False
    适用于频谱图尺寸能被patch_size整除的场景,通过reshapeeinsum分割:

    • 先将[N, 1, H, W]拆分为[N, 1, h, p, w, p]h=H//pw=W//p),再用torch.einsum('nchpwq->nhwpqc'调整维度顺序(将Patch内的p×p像素放在最后),最后合并为[N, h×w, p²×1]
2. 图像处理(self.audio_exp=False

图像通常为3通道(C=3),空间结构规则,采用非重叠Patch

  • 输入[N, 3, H, W](如[2, 3, 224, 224]),patch_size=16,则h=w=224//16=14,共14×14=196个Patch。
  • 先拆分为[N, 3, 14, 16, 14, 16],通过einsum调整维度顺序为[N, 14, 14, 16, 16, 3],最终reshape为[N, 196, 16²×3] = [N, 196, 768]
核心作用

将原始二维数据转换为模型可处理的序列格式([N, L, D]D=p²×C),作为编码器的输入(后续会被掩码并编码)。

二、unpatchify:将Patch序列还原为原始数据
功能

将形状为 [N, L, p²×C] 的Patch序列(模型解码器输出的重建结果)还原为 [N, C, H, W] 的原始数据形状,用于与输入数据对比计算重建损失。

方法定义
def unpatchify(self, x):"""x: (N, L, patch_size**2 * C)  # Patch序列(重建结果)specs: (N, C, H, W)  # 还原的原始数据形状(音频频谱图)"""p = self.patch_embed.patch_size[0]  # Patch大小# 音频频谱图的高度和宽度(H=1024,W=128,根据实际场景调整)h = 1024 // p  w = 128 // p  # 形状转换:[N, L, p²×1] → [N, h, w, p, p, 1]x = x.reshape(shape=(x.shape[0], h, w, p, p, 1))# 调整维度顺序:[N, h, w, p, p, 1] → [N, 1, h, p, w, p]x = torch.einsum('nhwpqc->nchpwq', x)# 合并Patch内的像素,还原为原始尺寸:[N, 1, h*p, w*p] = [N, 1, 1024, 128]specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p))return specs
关键逻辑

unpatchifypatchify的逆操作,核心是将“Patch序列”的维度重新映射回“原始数据”的空间维度(H, W):

  1. reshape拆分:将[N, L, p²×1]拆分为[N, h, w, p, p, 1],其中h×w=L(Patch总数),h=H//pw=W//p
  2. 维度顺序调整:通过torch.einsum('nhwpqc->nchpwq'将通道维度(c)提前,得到[N, 1, h, p, w, p],与patchify中的中间形状对应。
  3. 合并还原:将hp合并为H=h×pwp合并为W=w×p,最终得到[N, 1, H, W]的原始频谱图形状。
适配场景

代码中unpatchify主要适配音频频谱图(C=1),若处理图像(C=3),只需将通道数改为3,并调整hw的计算(如h=w=224//16=14),逻辑完全一致。

三、patchifyunpatchify的互逆性与核心作用
  1. 互逆性
    对原始数据imgs执行unpatchify(patchify(imgs)),结果应与imgs完全一致(忽略数值误差),这是确保重建损失计算准确的前提。

  2. 在MAE中的作用

    • patchify:将输入imgs转换为Patch序列,作为编码器的输入(后续被掩码、编码)。
    • unpatchify:将解码器输出的重建Patch序列(pred)还原为原始数据形状,与imgs计算MSE损失(forward_loss中使用),迫使模型学习从部分Patch重建完整数据。
总结

patchifyunpatchify是MAE中连接“原始数据”与“模型序列输入/输出”的桥梁:

  • patchify通过滑动窗口(重叠)或均匀分割(非重叠)将二维数据转换为Patch序列,适配模型的Transformer结构;
  • unpatchify通过维度重组将Patch序列还原为原始形状,确保重建损失的准确计算。
    两者的设计充分考虑了音频(单通道、重叠Patch)和图像(三通道、非重叠Patch)的差异,体现了模型的多模态适配能力。

2.4 掩码机制:random_maskingrandom_masking_2d

在掩码自编码器(MAE)中,掩码机制是核心设计之一,其作用是随机掩盖输入数据的部分Patch,迫使模型从剩余的少量Patch中学习重建完整数据,从而提取更鲁棒的特征表示。代码中实现了两种掩码机制:random_masking(1D序列掩码)和random_masking_2d(2D结构化掩码),分别适配不同的数据结构(图像的1D Patch序列 vs 音频频谱图的2D时间-频率结构)。以下是详细解析:

一、random_masking:1D序列掩码(适用于图像等1D Patch序列)
功能

将输入的Patch序列(形状[N, L, D]N=批次,L=总Patch数,D=特征维度)视为一维序列,按指定比例(mask_ratio)随机掩盖部分Patch,输出掩码后的序列、掩码矩阵及原始顺序恢复索引。

方法定义
def random_masking(self, x, mask_ratio):"""对一维Patch序列进行随机掩码x: [N, L, D]  # 输入的Patch序列返回:x_masked: [N, L_keep, D]  # 掩码后保留的Patch序列(L_keep = L*(1-mask_ratio))mask: [N, L]  # 掩码矩阵(0=保留,1=掩盖)ids_restore: [N, L]  # 恢复原始顺序的索引"""N, L, D = x.shape  # 解析输入维度len_keep = int(L * (1 - mask_ratio))  # 计算需要保留的Patch数量# 生成随机噪声(用于决定哪些Patch被保留)noise = torch.rand(N, L, device=x.device)  # 形状[N, L],值在[0,1)之间# 对噪声排序,得到掩盖/保留的索引ids_shuffle = torch.argsort(noise, dim=1)  # 按噪声升序排序,小值对应保留的Patchids_restore = torch.argsort(ids_shuffle, dim=1)  # 用于恢复原始顺序的索引(对shuffle索引再排序)# 提取需要保留的Patchids_keep = ids_shuffle[:, :len_keep]  # 取前len_keep个索引(噪声最小的Patch)x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))  # 按索引提取保留的Patch# 生成掩码矩阵(0=保留,1=掩盖)mask = torch.ones([N, L], device=x.device)  # 初始化全为1(默认掩盖)mask[:, :len_keep] = 0  # 前len_keep个位置设为0(保留)mask = torch.gather(mask, dim=1, index=ids_restore)  # 按原始顺序恢复掩码矩阵return x_masked, mask, ids_restore
关键步骤解析
  1. 确定保留数量:根据mask_ratio计算保留的Patch数量(len_keep = L*(1-mask_ratio)),例如mask_ratio=0.75时保留25%的Patch。
  2. 随机噪声生成:生成[N, L]的随机噪声,每个Patch对应一个噪声值,用于后续排序(噪声越小的Patch越可能被保留)。
  3. 索引排序
    • ids_shuffle:对噪声按行排序的索引(升序),即“噪声最小的Patch排在最前”。
    • ids_keep:取ids_shuffle的前len_keep个索引,即需要保留的Patch索引。
  4. 提取保留的Patch:用torch.gatherids_keep从原始序列x中提取保留的Patch,得到x_masked(形状[N, len_keep, D])。
  5. 生成掩码矩阵
    • 初始掩码全为1(表示掩盖),前len_keep个位置设为0(表示保留)。
    • ids_restoreids_shuffle的逆排序)将掩码矩阵恢复为原始Patch顺序,确保mask[i,j]对应原始序列中第j个Patch是否被掩盖。
核心特点
  • 无差别掩码:将Patch视为一维序列,不区分空间/时序位置,按比例随机掩盖,适用于图像等“全局结构无明显维度差异”的数据。
  • 高效性:通过噪声排序实现随机掩码,避免显式采样,计算效率高。
二、random_masking_2d:2D结构化掩码(适用于音频频谱图等2D结构)
功能

针对具有二维结构的Patch(如音频频谱图的“时间-频率”维度),分别对时间维度(T)和频率维度(F)按指定比例(mask_t_probmask_f_prob)进行掩码,最终输出掩码后的序列、合并的掩码矩阵及原始顺序恢复索引。相比1D掩码,2D掩码更贴合数据的物理结构(如音频的时间连续性和频率相关性)。

方法定义
def random_masking_2d(self, x, mask_t_prob, mask_f_prob):"""对二维结构的Patch(时间T×频率F)进行掩码x: [N, L, D]  # 输入的Patch序列(L = T*F)返回:x_masked: [N, L_keep, D]  # 掩码后保留的Patch序列(L_keep = T_keep*F_keep)mask: [N, L]  # 合并的掩码矩阵(0=保留,1=掩盖)ids_restore: [N, L]  # 恢复原始顺序的索引"""N, L, D = x.shape  # 解析输入维度# 定义时间(T)和频率(F)维度的大小(根据Patch划分方式)if self.use_custom_patch:  # 重叠Patch(音频场景)T = 101  # 时间维度Patch数F = 12   # 频率维度Patch数else:  # 非重叠PatchT = 64F = 8# 计算时间和频率维度保留的数量len_keep_t = int(T * (1 - mask_t_prob))  # 时间维度保留数len_keep_f = int(F * (1 - mask_f_prob))  # 频率维度保留数# -------------------------- 时间维度掩码 --------------------------noise_t = torch.rand(N, T, device=x.device)  # 时间维度随机噪声[N, T]ids_shuffle_t = torch.argsort(noise_t, dim=1)  # 时间维度排序索引(升序)ids_restore_t = torch.argsort(ids_shuffle_t, dim=1)  # 时间维度恢复索引ids_keep_t = ids_shuffle_t[:, :len_keep_t]  # 时间维度保留的索引# -------------------------- 频率维度掩码 --------------------------noise_f = torch.rand(N, F, device=x.device)  # 频率维度随机噪声[N, F]ids_shuffle_f = torch.argsort(noise_f, dim=1)  # 频率维度排序索引(升序)ids_restore_f = torch.argsort(ids_shuffle_f, dim=1)  # 频率维度恢复索引ids_keep_f = ids_shuffle_f[:, :len_keep_f]  # 频率维度保留的索引# -------------------------- 合并掩码矩阵 --------------------------# 生成时间维度掩码(0=保留,1=掩盖)并扩展为[N, T, F]mask_t = torch.ones(N, T, device=x.device)mask_t[:, :len_keep_t] = 0mask_t = torch.gather(mask_t, dim=1, index=ids_restore_t).unsqueeze(1).repeat(1, F, 1).permute(0, 2, 1)  # [N, T, F]# 生成频率维度掩码(0=保留,1=掩盖)并扩展为[N, T, F]mask_f = torch.ones(N, F, device=x.device)mask_f[:, :len_keep_f] = 0mask_f = torch.gather(mask_f, dim=1, index=ids_restore_f).unsqueeze(1).repeat(1, T, 1)  # [N, T, F]# 合并掩码:时间或频率任一被掩盖,则该Patch被掩盖(1 - (1-时间掩码)*(1-频率掩码))mask = 1 - (1 - mask_t) * (1 - mask_f)  # [N, T, F]mask = mask.flatten(start_dim=1)  # 展平为[N, L](L=T*F)# -------------------------- 提取保留的Patch --------------------------# 生成原始Patch索引,并给被掩盖的Patch加一个大值(确保排序后被放到最后)id2res = torch.Tensor(list(range(N*T*F))).reshape(N, T, F).to(x.device)  # [N, T, F]的原始索引id2res = id2res + 999 * mask  # 被掩盖的Patch索引 += 999(值变大)# 展平后排序,取前L_keep个索引(未被掩盖的Patch)id2res2 = torch.argsort(id2res.flatten(start_dim=1))  # 按索引值升序排序ids_keep = id2res2.flatten(start_dim=1)[:, :len_keep_t*len_keep_f]  # 保留的Patch索引x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))  # 提取保留的Patch# 生成恢复原始顺序的索引ids_restore = torch.argsort(id2res2.flatten(start_dim=1))  # [N, L]return x_masked, mask, ids_restore
关键步骤解析

2D掩码的核心是区分时间和频率维度独立掩码,再合并结果,步骤如下:

  1. 维度定义:根据use_custom_patch确定二维结构的时间维度(T)和频率维度(F),例如音频重叠Patch时T=101F=12(总Patch数L=T*F=1212)。

  2. 时间维度掩码

    • 生成[N, T]的随机噪声noise_t,排序后得到保留的时间索引ids_keep_t(数量len_keep_t = T*(1-mask_t_prob))。
    • 生成时间掩码mask_t[N, T],0=保留,1=掩盖),扩展为[N, T, F](与频率维度对齐)。
  3. 频率维度掩码

    • 生成[N, F]的随机噪声noise_f,排序后得到保留的频率索引ids_keep_f(数量len_keep_f = F*(1-mask_f_prob))。
    • 生成频率掩码mask_f[N, F],0=保留,1=掩盖),扩展为[N, T, F](与时间维度对齐)。
  4. 合并掩码矩阵

    • 合并逻辑:mask = 1 - (1 - mask_t) * (1 - mask_f),即“时间或频率任一维度被掩盖,则该Patch被掩盖”(避免仅掩盖单维度导致信息残留)。
    • 合并后展平为[N, L]L=T*F),与原始Patch序列长度一致。
  5. 提取保留的Patch

    • 为原始Patch索引(id2res)中被掩盖的位置加“大值”(999),确保排序后这些Patch被放在最后。
    • 对索引排序后,取前len_keep_t*len_keep_f个索引(未被掩盖的Patch),用torch.gather提取得到x_masked
  6. 生成恢复索引ids_restore为排序索引的逆,用于解码器将掩码后的序列还原到原始顺序,确保重建时Patch位置对齐。

核心特点
  • 维度感知:区分时间和频率维度分别掩码,更贴合音频频谱图等二维数据的物理结构(如时间连续性、频率相关性)。
  • 灵活控制:通过mask_t_probmask_f_prob独立控制两个维度的掩码强度(例如对时间维度掩盖60%,频率维度掩盖50%)。
  • 严格掩盖:合并掩码时采用“或”逻辑,确保被掩盖的Patch在至少一个维度上缺失信息,增强重建难度,迫使模型学习更本质的特征。
三、两种掩码机制的对比与适用场景
对比项random_masking(1D)random_masking_2d(2D)
数据结构假设一维Patch序列(无维度差异)二维结构化Patch(如时间-频率)
掩码粒度整个序列无差别掩码区分两个维度独立掩码
掩码比例控制单一mask_ratio控制整体比例mask_t_probmask_f_prob分别控制
适用场景图像(正方形Patch,全局结构)音频频谱图(时间-频率结构)
核心优势实现简单,计算高效贴合二维数据物理结构,特征学习更有效
总结

random_maskingrandom_masking_2d是MAE中实现“信息缺失驱动学习”的核心机制:

  • 1D掩码适用于无明显维度差异的序列数据,通过无差别掩码迫使模型学习全局结构;
  • 2D掩码适用于二维结构化数据(如音频频谱图),通过维度感知的掩码策略,更精准地模拟真实场景中的信息缺失(如局部时间片段或频率成分丢失)。

两种机制均通过“随机噪声排序-索引提取-掩码矩阵生成”的流程实现高效掩码,并通过恢复索引确保解码器能正确还原序列顺序,为后续重建任务奠定基础。

2.5 编码器前向传播:forward_encoderforward_encoder_no_mask

MaskedAutoencoderViT 模型中,编码器的前向传播方法分为两种:forward_encoder(带掩码的前向传播)和 forward_encoder_no_mask(无掩码的前向传播)。两者均基于 Transformer 架构提取输入数据的特征,但适用场景和处理流程存在显著差异:

  • forward_encoder 是 MAE 自监督训练的核心,通过随机掩码部分 Patch 迫使模型从有限信息中学习数据结构;
  • forward_encoder_no_mask 用于处理完整输入(无掩码),主要用于预训练后的特征提取或下游任务微调。
一、forward_encoder:带掩码的编码器前向传播(自监督训练核心)
功能

接收原始输入数据,通过 Patch 嵌入、位置编码、随机掩码部分 Patch 后,经 Transformer 编码器提取特征,输出掩码后的特征序列、掩码矩阵及原始顺序恢复索引,为解码器的重建任务提供输入。

方法定义
def forward_encoder(self, x, mask_ratio, mask_2d=False):# 1. 对输入数据进行Patch嵌入(转换为Patch序列)x = self.patch_embed(x)# 2. 为Patch添加位置嵌入(不含cls_token的位置)x = x + self.pos_embed[:, 1:, :]# 3. 对Patch序列进行随机掩码(1D或2D掩码)if mask_2d:x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob)else:x, mask, ids_restore = self.random_masking(x, mask_ratio)# 4. 拼接cls_token(并添加其位置嵌入)cls_token = self.cls_token + self.pos_embed[:, :1, :]  # cls_token的位置嵌入是pos_embed的第0位cls_tokens = cls_token.expand(x.shape[0], -1, -1)  # 扩展到批次维度:[1,1,D] → [N,1,D]x = torch.cat((cls_tokens, x), dim=1)  # 拼接:[N, L_keep, D] → [N, L_keep+1, D](+1为cls_token)# 5. 通过Transformer编码器块提取特征for blk in self.blocks:x = blk(x)x = self.norm(x)  # 编码器输出归一化return x, mask, ids_restore, None  # 返回编码特征、掩码、恢复索引
关键步骤解析
1. Patch嵌入(self.patch_embed(x)
  • 输入 x 为原始数据(图像/音频频谱图,形状 [N, C, H, W]),经 patch_embed 转换为 Patch 序列(形状 [N, L, embed_dim]L 为总 Patch 数,embed_dim 为编码器特征维度)。
  • 例如:音频频谱图 [N, 1, 1024, 128] 经重叠 Patch 嵌入后变为 [N, 1212, 1024]L=101×12=1212embed_dim=1024)。
2. 添加位置嵌入(x = x + self.pos_embed[:, 1:, :]
  • 位置嵌入 self.pos_embed 形状为 [1, L+1, embed_dim]+1cls_token 预留位置),此处仅取 [:, 1:, :](即 Patch 对应的位置嵌入),与 Patch 序列 x 相加,赋予模型空间/时序位置感知能力。
3. 随机掩码(random_maskingrandom_masking_2d
  • 根据 mask_2d 选择掩码方式:
    • 1D 掩码(mask_2d=False):调用 random_masking,按 mask_ratio 随机掩盖部分 Patch,输出掩码后的序列 x(形状 [N, L_keep, embed_dim]L_keep 为保留的 Patch 数)。
    • 2D 掩码(mask_2d=True):调用 random_masking_2d,按 mask_t_prob(时间维度)和 mask_f_prob(频率维度)掩盖 Patch,输出掩码后的序列。
  • 同时返回 mask(掩码矩阵,[N, L],0=保留,1=掩盖)和 ids_restore(恢复原始顺序的索引,[N, L]),用于解码器还原序列顺序。
4. 拼接 cls_tokentorch.cat((cls_tokens, x), dim=1)
  • cls_token 是可学习的全局特征向量(形状 [1, 1, embed_dim]),先添加其专属位置嵌入(self.pos_embed[:, :1, :]),再扩展到批次维度([N, 1, embed_dim])。
  • 与掩码后的 Patch 序列 x 拼接,得到 [N, L_keep+1, embed_dim]+1cls_token),使 cls_token 能聚合所有保留 Patch 的全局信息。
5. Transformer 编码(for blk in self.blocks: x = blk(x)
  • 输入序列经 depth 个 Transformer 块(self.blocks)处理,每个块包含“多头自注意力”和“MLP”,通过残差连接和层归一化提取高层特征。
  • 最终经 self.norm 归一化,输出编码特征 x(形状 [N, L_keep+1, embed_dim])。
输出与作用
  • 输出:编码特征 x(掩码后 Patch 序列+cls_token 的特征)、掩码矩阵 mask、原始顺序索引 ids_restore
  • 作用:为解码器提供“已编码的保留 Patch 特征”,结合 maskids_restore,使解码器能定位被掩盖的 Patch 位置并重建。
二、forward_encoder_no_mask:无掩码的编码器前向传播(特征提取/微调)
功能

处理完整输入数据(不进行掩码),通过 Transformer 编码器提取特征,主要用于预训练后对完整数据的特征提取(如下游分类任务微调),或获取上下文丰富的特征表示。

方法定义
def forward_encoder_no_mask(self, x):# 1. 对输入数据进行Patch嵌入x = self.patch_embed(x)# 2. 为Patch添加位置嵌入(不含cls_token的位置)x = x + self.pos_embed[:, 1:, :]# 3. 拼接cls_token(并添加其位置嵌入)cls_token = self.cls_token + self.pos_embed[:, :1, :]cls_tokens = cls_token.expand(x.shape[0], -1, -1)x = torch.cat((cls_tokens, x), dim=1)  # [N, L+1, embed_dim](L为总Patch数,无掩码)# 4. 通过Transformer编码器块,收集深层上下文特征contextual_embs = []for n, blk in enumerate(self.blocks):x = blk(x)# 收集contextual_depth之后的层输出(经归一化)if n > self.contextual_depth:contextual_embs.append(self.norm(x))# 对收集的深层特征取平均,作为最终上下文特征contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0)return contextual_emb
关键步骤解析
1-3. Patch嵌入、位置编码与 cls_token 拼接
  • 流程与 forward_encoder 前3步一致,但不进行掩码,因此 Patch 序列为完整序列(L 个 Patch,无删减),拼接 cls_token 后形状为 [N, L+1, embed_dim]
4. 收集深层上下文特征(contextual_embs
  • forward_encoder 直接通过所有 Transformer 块后输出不同,forward_encoder_no_mask收集多个深层 Transformer 块的输出
    • self.contextual_depth 为阈值(如8),仅收集索引 n > contextual_depth 的 Transformer 块输出(即更深层的特征)。
    • 每个深层输出经 self.norm 归一化后存入 contextual_embs,最终通过 torch.stack 拼接并取平均(mean(dim=0)),得到 contextual_emb(形状 [N, L+1, embed_dim])。
设计动机

深层 Transformer 块的输出通常包含更抽象、更全局的特征(浅层特征偏向局部细节),通过收集多个深层特征并取平均,可获得更鲁棒的上下文特征,提升下游任务(如分类)的性能。

输出与作用
  • 输出:contextual_emb(深层 Transformer 特征的平均值,[N, L+1, embed_dim])。
  • 作用:为下游任务提供高质量特征(如取 contextual_emb[:, 0, :] 作为全局特征用于分类),或用于模型微调阶段的特征学习。
三、两种编码器前向传播的对比
对比项forward_encoder(带掩码)forward_encoder_no_mask(无掩码)
核心操作包含随机掩码步骤(掩盖部分Patch)无掩码,使用完整Patch序列
输入处理输出掩码后的Patch序列特征输出完整Patch序列的深层特征平均值
关键输出编码特征、掩码矩阵、恢复索引上下文特征(深层特征平均)
适用场景MAE自监督预训练(配合解码器重建)预训练后特征提取、下游任务微调
设计目标迫使模型从有限信息学习数据结构提取鲁棒的完整特征用于下游任务
序列长度短(L_keep + 1,仅保留部分Patch)长(L + 1,包含所有Patch)
总结

forward_encoderforward_encoder_no_mask 是编码器针对不同阶段设计的前向传播方法:

  • forward_encoder 是 MAE 自监督训练的核心,通过掩码机制创造“信息缺失”场景,驱动模型学习数据的内在结构,为解码器的重建任务提供输入;
  • forward_encoder_no_mask 专注于完整输入的特征提取,通过聚合深层 Transformer 特征,为下游任务提供高质量的上下文特征,实现预训练模型的迁移应用。

两者共同构成了 MAE“预训练-微调”全流程的编码器逻辑,兼顾了自监督学习的特征学习能力和下游任务的实用性。

2.6 解码器前向传播:forward_decoder

forward_decoder 是掩码自编码器(MAE)中负责“重建被掩盖Patch”的核心方法。它接收编码器输出的“未掩盖Patch特征”,结合掩码Token和位置信息,通过解码器网络(Transformer或Swin Transformer)重建完整的Patch序列,最终输出可与原始数据对比的重建结果。以下是详细解析:

功能与输入输出
  • 功能:从编码器输出的“掩码后特征”中恢复完整的Patch序列,实现被掩盖区域的重建。
  • 输入
    • x:编码器输出的特征(形状 [N, L_keep+1, embed_dim]L_keep 为未掩盖的Patch数,+1cls_token)。
    • ids_restore:用于将掩码后的序列恢复为原始Patch顺序的索引(形状 [N, L]L 为总Patch数)。
  • 输出
    • pred:重建的Patch序列(形状 [N, L, p²×C]p 为Patch大小,C 为通道数),可通过 unpatchify 还原为原始数据形状。
方法定义与详细步骤
def forward_decoder(self, x, ids_restore):# 1. 将编码器特征投影到解码器维度x = self.decoder_embed(x)  # [N, L_keep+1, embed_dim] → [N, L_keep+1, decoder_embed_dim]# 2. 生成并拼接掩码Token(填充被掩盖的Patch位置)# 计算需要填充的掩码Token数量:总Patch数 + 1(cls_token) - 编码器输出长度mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)# 拼接未掩盖的Patch特征(不含cls_token)和掩码Token → [N, L_keep + (L - L_keep), decoder_embed_dim] = [N, L, decoder_embed_dim]x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # 去掉cls_token,拼接掩码Token# 按ids_restore恢复原始Patch顺序(确保位置对齐)x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))# 拼接回cls_token → [N, L+1, decoder_embed_dim](L+1 = 总Patch数 + cls_token)x = torch.cat([x[:, :1, :], x_], dim=1)# 3. 添加解码器位置嵌入(确保位置信息对齐)x = x + self.decoder_pos_embed  # [N, L+1, decoder_embed_dim]# 4. 适配Swin解码器的二维结构(若启用)if self.decoder_mode != 0:B, L, D = x.shapex = x[:, 1:, :]  # 移除cls_token,仅保留Patch特征 → [N, L, decoder_embed_dim]if self.use_custom_patch:  # 重叠Patch(音频场景)# 调整形状为二维结构(时间T×频率F),并补全尺寸(避免维度不匹配)x = x.reshape(B, 101, 12, D)x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1)  # 补全一行,适配Swin窗口x = x.reshape(B, 1224, D)  # 102×12=1224# 5. 通过解码器块提取特征(Transformer或Swin Transformer)if self.decoder_mode > 3:  # 预留的mvit解码器模式x = self.decoder_blocks(x)else:for blk in self.decoder_blocks:x = blk(x)  # 逐块处理# 6. 解码器输出归一化x = self.decoder_norm(x)  # [N, L+1 (或L), decoder_embed_dim]# 7. 预测Patch的像素值(投影到原始Patch维度)pred = self.decoder_pred(x)  # [N, L+1 (或L), p²×C]# 8. 移除cls_token,仅保留Patch的重建结果if self.decoder_mode != 0:if self.use_custom_patch:# 还原重叠Patch的形状,去掉补全的部分pred = pred.reshape(B, 102, 12, 256)pred = pred[:, :101, :, :]  # 去掉补全的一行pred = pred.reshape(B, 1212, 256)  # 101×12=1212(原始总Patch数)else:pred = pred[:, 1:, :]  # 普通解码器直接移除cls_tokenreturn pred, None, None  # 返回重建的Patch序列
关键步骤解析
1. 编码器特征投影(self.decoder_embed(x)
  • 编码器输出特征的维度为 embed_dim(如1024),而解码器的特征维度为 decoder_embed_dim(如512)。这一步通过线性层(self.decoder_embed)将维度转换,确保与解码器后续层的输入维度匹配。
  • 形状变化:[N, L_keep+1, embed_dim] → [N, L_keep+1, decoder_embed_dim]
2. 掩码Token拼接与顺序恢复

这是解码器重建的核心步骤,目的是填充被掩盖的Patch位置并恢复原始顺序

  • 掩码Token生成mask_tokens 是可学习的向量(self.mask_token),数量为 ids_restore.shape[1] + 1 - x.shape[1](即总Patch数 L + 1(cls_token) - 编码器输出长度 L_keep+1 = L - L_keep,恰好等于被掩盖的Patch数)。
  • 拼接未掩盖特征与掩码Tokenx_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) 去掉编码器输出中的 cls_token,拼接掩码Token,得到长度为 L_keep + (L - L_keep) = L 的序列(所有Patch位置均被填充,未掩盖位置为编码器特征,掩盖位置为掩码Token)。
  • 恢复原始顺序torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) 利用 ids_restore 将拼接后的序列重新排列为原始Patch顺序(确保重建的Patch与输入数据的位置一一对应)。
  • 拼接cls_token:最后将 cls_token 拼接回序列,得到 [N, L+1, decoder_embed_dim](与解码器位置嵌入的形状匹配)。
3. 添加解码器位置嵌入(x = x + self.decoder_pos_embed
  • 解码器位置嵌入(self.decoder_pos_embed)形状为 [1, L+1, decoder_embed_dim],与编码器位置嵌入结构一致,但维度适配解码器。添加位置嵌入后,解码器能正确理解Patch的空间/时序位置关系,确保重建的空间一致性。
4. 适配Swin解码器的二维结构(if self.decoder_mode != 0

当使用Swin Transformer作为解码器(decoder_mode=1)时,需将序列特征调整为二维结构(适配Swin的窗口注意力机制):

  • 移除cls_token:Swin解码器专注于Patch的空间重建,无需 cls_token,因此去掉 x[:, 1:, :]
  • 形状调整:对于重叠Patch(音频场景),将序列 [N, L, D] 重塑为 [N, T, F, D]T=时间维度Patch数,F=频率维度Patch数,如101×12),并通过 torch.cat 补全一行(适配Swin的窗口大小,避免维度不匹配),最后重塑回序列格式输入解码器块。
5. 解码器块特征提取(for blk in self.decoder_blocks: x = blk(x)
  • 根据 decoder_mode 选择解码器类型:
    • 普通Transformer解码器(decoder_mode=0:通过 decoder_depth 个标准Transformer块(Block)处理,每个块含多头自注意力和MLP,捕捉Patch间的长距离依赖,优化重建特征。
    • Swin Transformer解码器(decoder_mode=1:通过16个 SwinTransformerBlock 处理,利用窗口注意力和移位窗口机制,更高效地捕捉二维结构(如音频时间-频率)的局部相关性。
6. 归一化与像素预测(self.decoder_norm(x)self.decoder_pred(x)
  • 归一化self.decoder_norm 对解码器块的输出进行层归一化,稳定特征分布。
  • 像素预测self.decoder_pred 是线性层,将解码器特征(decoder_embed_dim)投影到“单个Patch的像素维度”(p²×C,如16×16×1=256 for 音频单通道),得到每个Patch的重建像素值。
7. 移除cls_token(最终输出处理)
  • 重建目标是原始数据的Patch序列,因此需移除 cls_token
    • 普通解码器直接取 pred[:, 1:, :]
    • Swin解码器(重叠Patch)需先还原二维形状,去掉之前补全的行,再重塑为原始总Patch数 L,确保输出形状与 patchify 后的原始Patch序列一致。
核心作用与设计逻辑

解码器的核心目标是从“部分观察”(未掩盖的Patch)推断“全局完整信息”(被掩盖的Patch),其设计逻辑围绕以下原则:

  1. 信息补全:通过掩码Token填充被掩盖位置,为解码器提供完整的序列结构。
  2. 位置对齐:解码器位置嵌入确保重建的Patch与原始数据的空间/时序位置一一对应。
  3. 结构适配:支持Transformer(全局注意力)和Swin Transformer(局部窗口注意力),分别适配图像的全局结构和音频的二维时间-频率结构。
  4. 维度匹配:通过投影层确保编码器特征与解码器维度兼容,最终预测层输出与原始Patch的像素维度一致,为损失计算(forward_loss)提供可对比的重建结果。
总结

forward_decoder 是MAE实现“重建任务”的核心模块,通过“特征投影-掩码填充-顺序恢复-位置编码-深度特征提取-像素预测”的流程,将编码器输出的有限信息转化为完整的Patch序列重建结果。其设计兼顾了多模态数据(图像/音频)的结构差异,通过灵活的解码器类型(Transformer/Swin)和形状调整,确保重建精度和效率,最终辅助模型学习数据的内在结构特征。

2.7 损失计算:forward_loss

forward_loss 是掩码自编码器(MAE)中计算“重建损失”的核心方法,用于衡量模型重建的被掩码Patch与原始数据中对应Patch的差异。该损失是MAE自监督训练的优化目标,直接驱动模型学习从部分信息中重建完整数据的能力。以下是详细解析:

功能与输入输出
  • 功能:仅针对被掩码的Patch计算重建误差(忽略未掩码的Patch),确保模型专注于学习从有限信息中恢复缺失内容。
  • 输入
    • imgs:原始输入数据(形状 [N, C, H, W]N=批次,C=通道,H/W=高度/宽度)。
    • pred:解码器输出的重建Patch序列(形状 [N, L, p²×C]L=总Patch数,p=Patch大小)。
    • mask:掩码矩阵(形状 [N, L],0=未掩码,1=被掩码)。
  • 输出
    • 被掩码Patch的平均重建损失(标量,通常为MSE损失)。
方法定义与详细步骤
def forward_loss(self, imgs, pred, mask):"""imgs: [N, C, H, W]  # 原始输入数据pred: [N, L, p²×C]  # 解码器重建的Patch序列mask: [N, L]  # 掩码矩阵(1表示被掩码的Patch)"""# 1. 将原始图像分割为Patch序列(与pred的形状匹配)target = self.patchify(imgs)  # [N, L, p²×C]# 2. 计算重建Patch与原始Patch的MSE损失(逐元素)loss = (pred - target) **2  # [N, L, p²×C],每个元素的平方误差loss = loss.mean(dim=-1)  # 对单个Patch内的所有像素求平均 → [N, L]# 3. 仅计算被掩码Patch的损失(mask=1的位置),并求平均loss = (loss * mask).sum() / mask.sum()  # 被掩码Patch的平均损失return loss

#####** 关键步骤解析 **

1. 原始数据Patch化(target = self.patchify(imgs)
  • 调用 patchify 方法将原始输入 imgs 分割为与 pred 形状完全一致的Patch序列 target(形状 [N, L, p²×C])。
  • 这一步确保“重建结果”与“原始数据”在Patch级别对齐,可直接计算逐Patch的误差。
  • 例如:原始音频频谱图 [N, 1, 1024, 128]patchify 后变为 [N, 1212, 256]L=101×12=1212p=16p²×C=16²×1=256),与 pred 形状完全匹配。
2. 逐元素MSE计算(`loss = (pred - target)

-** 逐元素平方误差 :计算重建Patch pred 与原始Patch target 的差值平方,得到形状 [N, L, p²×C] 的误差矩阵(每个元素对应单个像素的重建误差)。
-
单个Patch内平均 **:通过 loss.mean(dim=-1) 对每个Patch的所有像素(p²×C 维度)求平均,得到 [N, L] 的误差矩阵(每个元素对应一个Patch的平均重建误差)。

#####** 3. 掩码Patch的损失聚合(loss = (loss * mask).sum() / mask.sum()- 这是MAE损失计算的核心设计: 仅关注被掩码的Patch **,忽略未掩码的Patch(因为未掩码的Patch已被编码器直接观察到,重建它们无法体现模型的推断能力)。

  • loss * mask:通过掩码矩阵 mask(1=被掩码,0=未掩码)过滤误差,仅保留被掩码Patch的误差(未掩码Patch的误差被置为0)。
  • sum():对所有被掩码Patch的误差求和。
  • / mask.sum():除以被掩码Patch的总数量(mask.sum() 得到批次中被掩码的Patch总数),得到被掩码Patch的平均重建损失。

###** 核心设计原则 1. 聚焦掩码区域 :仅计算被掩码Patch的损失,迫使模型学习从“未观察信息”(未掩码Patch)推断“缺失信息”(被掩码Patch),而非简单复制已观察内容。
2.
平均化处理 **:

  • 先对单个Patch内的所有像素求平均(mean(dim=-1)),避免单个Patch因像素数量多(如 p=16 时有256个像素)而主导损失。
  • 再对所有被掩码Patch求平均(sum() / mask.sum()),确保不同批次中掩码数量不同时损失仍可比较。
    3.** 与Patch划分一致 **:依赖 patchify 方法确保 predtarget 的形状严格匹配,保证误差计算的准确性。

###** 与其他损失计算的对比 - 全Patch损失 :若计算所有Patch(包括未掩码)的损失,模型可能“偷懒”复制未掩码Patch的信息,而不学习真正的推断能力。
-
逐像素损失 **:若直接对原始图像和重建图像计算逐像素损失(不通过Patch),会包含未掩码区域的误差,违背MAE“掩码重建”的设计初衷。

  • MAE的 forward_loss 通过聚焦掩码区域,精准对齐了自监督学习的目标:** 从部分观察中学习数据的全局结构 **。

###** 总结**
forward_loss 是MAE训练的“指挥棒”,通过以下逻辑驱动模型优化:

  1. 将原始数据和重建结果转换为Patch序列,确保误差计算在相同粒度上进行;
  2. 仅对被掩码的Patch计算平均MSE损失,引导模型专注于学习从有限信息中恢复缺失内容的能力;
  3. 损失的平均化处理保证了不同批次、不同掩码数量下的损失可比较性。

这一设计使MAE能高效学习数据的内在特征,为下游任务(如分类、分割)提供高质量的预训练模型。

2.8 模型总前向传播:forward

forwardMaskedAutoencoderViT 模型的总前向传播方法,它整合了编码器(带掩码)、解码器、损失计算等核心组件,实现了从原始输入数据到重建损失(或重建结果)的完整流程。该方法是模型训练和推理的入口,直接体现了MAE“掩码-编码-解码-重建-损失”的核心逻辑。以下是详细解析:

功能与输入输出
  • 功能:协调模型各模块(编码器、解码器、损失计算),完成“输入数据→掩码编码→解码重建→损失计算”的全流程。
  • 输入
    • imgs:原始输入数据(图像或音频频谱图,形状 [N, C, H, W]N=批次,C=通道,H/W=高度/宽度)。
    • mask_ratio:掩码比例(如0.75,表示掩盖75%的Patch)。
    • mask_2d:是否使用2D掩码(True 用于音频等二维结构数据,False 用于图像)。
  • 输出
    • 训练阶段:返回被掩码Patch的重建损失(标量)。
    • 推理阶段:可额外返回重建的原始数据(通过 unpatchify 还原),用于可视化或评估。
方法定义与详细步骤
def forward(self, imgs, mask_ratio=0.75, mask_2d=False):# 1. 编码器前向传播(带掩码):得到编码特征、掩码矩阵、原始顺序索引latent, mask, ids_restore, _ = self.forward_encoder(imgs, mask_ratio, mask_2d=mask_2d)# 2. 解码器前向传播:根据编码特征和索引重建Patch序列pred, _, _ = self.forward_decoder(latent, ids_restore)  # pred: [N, L, p²×C]# 3. 计算重建损失(仅针对被掩码的Patch)loss = self.forward_loss(imgs, pred, mask)# 4. (可选)还原重建的原始数据形状(用于可视化或推理)pred_spec = self.unpatchify(pred)  # [N, C, H, W],与原始输入形状一致return loss, pred_spec, mask
关键步骤解析
1. 编码器带掩码前向传播(self.forward_encoder
  • 输入:原始数据 imgs、掩码比例 mask_ratio、掩码类型 mask_2d
  • 处理逻辑
    • imgs 转换为Patch序列并添加位置嵌入。
    • 根据 mask_2d 选择1D或2D掩码机制,随机掩盖部分Patch。
    • 通过Transformer编码器提取保留Patch的特征,拼接 cls_token 并归一化。
  • 输出
    • latent:编码器输出的特征(形状 [N, L_keep+1, embed_dim]L_keep 为保留的Patch数,+1cls_token)。
    • mask:掩码矩阵([N, L],1=被掩码,0=未掩码)。
    • ids_restore:用于将掩码后序列恢复为原始顺序的索引([N, L])。
2. 解码器前向传播(self.forward_decoder
  • 输入:编码器特征 latent、恢复索引 ids_restore
  • 处理逻辑
    • latent 投影到解码器维度(decoder_embed_dim)。
    • 生成掩码Token(mask_token),填充被掩码的Patch位置,并通过 ids_restore 恢复原始顺序。
    • 添加解码器位置嵌入,经解码器(Transformer或Swin Transformer)提取特征后,预测每个Patch的像素值。
  • 输出
    • pred:重建的Patch序列(形状 [N, L, p²×C],与 patchify(imgs) 输出的原始Patch序列形状一致)。
3. 重建损失计算(self.forward_loss
  • 输入:原始数据 imgs、重建Patch序列 pred、掩码矩阵 mask
  • 处理逻辑
    • imgs 转换为原始Patch序列(target),与 pred 对齐。
    • 计算 predtarget 的MSE误差,仅对被掩码的Patch(mask=1)求平均,得到最终损失。
  • 输出:被掩码Patch的平均重建损失(标量),作为模型训练的优化目标。
4. 重建结果还原(self.unpatchify(pred)
  • 作用:将重建的Patch序列 pred 还原为与原始输入 imgs 形状一致的数据([N, C, H, W]),用于可视化重建效果(如对比原始图像与模型重建的图像)或推理阶段的结果输出。
数据流向与维度变化

为更清晰展示流程,以音频频谱图(N=2C=1H=1024W=128patch_size=16mask_ratio=0.75mask_2d=True)为例:

  1. 输入数据imgs 形状 [2, 1, 1024, 128]
  2. 编码器处理
    • patch_embed 转换为Patch序列 [2, 1212, 1024]L=101×12=1212embed_dim=1024)。
    • 2D掩码后保留 1212×(1-0.75)=303 个Patch,latent 形状 [2, 303+1, 1024]+1cls_token)。
    • 输出 mask[2, 1212])和 ids_restore[2, 1212])。
  3. 解码器处理
    • decoder_embed 投影为 [2, 304, 512]decoder_embed_dim=512)。
    • 填充 1212-303=909 个掩码Token,恢复顺序后序列长度为1212,经解码器处理后 pred 形状 [2, 1212, 256]p²×C=16²×1=256)。
  4. 损失计算
    • patchify(imgs) 得到原始Patch序列 [2, 1212, 256],与 pred 计算MSE,仅保留 mask=1 的位置,输出损失标量。
  5. 重建结果unpatchify(pred) 还原为 [2, 1, 1024, 128],与原始输入形状一致。
核心作用与设计逻辑

forward 方法是模型的“中枢”,其设计体现了MAE的自监督学习逻辑:

  1. 端到端流程:从原始数据输入到损失输出,整合了“Patch化-掩码-编码-解码-重建-损失”的全链路,确保训练过程可直接优化。
  2. 模块化协同:通过调用 forward_encoderforward_decoderforward_loss 等模块化方法,实现功能解耦,便于维护和扩展(如替换编码器/解码器结构)。
  3. 兼顾训练与推理:训练时返回损失用于优化,推理时返回重建结果用于可视化或评估,满足不同阶段的需求。
  4. 多模态适配:通过 mask_2d 参数切换1D/2D掩码,适配图像(1D序列)和音频频谱图(2D结构)等不同类型数据,体现模型的通用性。
总结

forward 方法是 MaskedAutoencoderViT 模型的核心入口,它通过协调编码器(带掩码)、解码器和损失计算模块,实现了从原始输入到重建损失的完整流程。其设计既保证了MAE“掩码重建”的自监督学习逻辑,又通过模块化和参数控制(如 mask_2d)适配了多模态数据,为模型的训练和应用提供了统一接口。

三、模型实例化函数

代码提供了4种规模的MAE模型实例化函数,参数与标准ViT对应,仅解码器统一为“512维+8层”(平衡性能与计算量):

函数名编码器配置(patch_size=16)解码器配置(固定)
mae_vit_small_patch16embed_dim=384, depth=12, num_heads=6decoder_embed_dim=512, depth=8
mae_vit_base_patch16embed_dim=768, depth=12, num_heads=12decoder_embed_dim=512, depth=8
mae_vit_large_patch16embed_dim=1024, depth=24, num_heads=16decoder_embed_dim=512, depth=8
mae_vit_huge_patch14embed_dim=1280, depth=32, num_heads=16decoder_embed_dim=512, depth=8

四、核心应用场景

该模型主要用于音频频谱图的自监督预训练(通过audio_exp=True、2D掩码、重叠Patch适配),预训练后可微调用于:

  • 音频分类(如环境声分类ESC-50、语音情感识别)。
  • 音频检索、异常音频检测等任务。
  • 也可通过audio_exp=False适配图像任务(如ImageNet分类)。

五、关键设计亮点

  1. 多模态适配:通过audio_exp区分音频/图像,灵活处理1/3通道数据。
  2. 精细掩码:2D掩码区分时间-频率维度,更贴合音频数据的物理结构。
  3. 灵活解码器:支持ViT/Swin解码器,Swin的窗口注意力更适合二维数据的局部相关性建模。
  4. 工程优化:重叠Patch、可训练位置嵌入、像素归一化等设计,提升训练稳定性和特征质量。
http://www.xdnf.cn/news/1381051.html

相关文章:

  • 流程控制语句(3)
  • 帕萨特盘式制动器cad+设计说明书
  • 【C语言16天强化训练】从基础入门到进阶:Day 13
  • week5-[一维数组]归并
  • 公共字段自动填充
  • 云计算学习100天-第29天
  • 基于SamOut的音频Token序列生成模型训练指南
  • Linux shell getopts 解析命令行参数
  • 算力沸腾时代,如何保持“冷静”?国鑫液冷SY4108G-G4解锁AI服务器的“绿色空调”!
  • 使用Rag 命中用户feedback提升triage agent 准确率
  • Elasticsearch数据迁移方案深度对比:三种方法的优劣分析
  • linu 网络 :TCP粘包及UDP
  • 【C++】C++11的右值引用和移动语义
  • STAGEWISE实战指南:从集成到使用的完整解决方案
  • vscode pyqt5设置
  • 【ai编辑器】使用cursor-vip获得cursor的pro版 pro plan(mac)
  • uniapp vue3 canvas实现手写签名
  • Flask测试平台开发,登陆重构
  • (二分查找)Leetcode34. 在排序数组中查找元素的第一个和最后一个位置+74. 搜索二维矩阵
  • 并发编程——05 并发锁机制之深入理解synchronized
  • 学习数据结构(13)二叉树链式结构下
  • 线程池及线程池单例模式
  • 带动态条件的模糊查询SQL
  • DINOv2 vs DINOv3 vs CLIP:自监督视觉模型的演进与可视化对比
  • LeetCode 3446. 按对角线进行矩阵排序
  • UE5提升分辨率和帧率的方法
  • 搭建私有云3步法:cpolar简化Puter本地云端配置
  • C# SIMD编程实践:工业数据处理性能优化案例
  • C++ 哈希概念版
  • 【实战笔记】OCI Ubuntu 24.04 + TigerVNC + XFCE + Chrome 开机自启全记录