audioMAE模型代码分析
VisionTransformer
audioMAE的核心是由ViT模型构成,所以先对ViT模型进行解释。
先给出完整代码:
from functools import partialimport torch
import torch.nn as nn
import numpy as np
import timm.models.vision_transformer
from timm.models.vision_transformer import PatchEmbed, Block
from util.patch_embed import PatchEmbed_new, PatchEmbed3D_newclass VisionTransformer(timm.models.vision_transformer.VisionTransformer):""" Vision Transformer with support for global average pooling"""def __init__(self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs):super(VisionTransformer, self).__init__(**kwargs)self.global_pool = global_poolif self.global_pool:norm_layer = kwargs['norm_layer']embed_dim = kwargs['embed_dim']self.fc_norm = norm_layer(embed_dim)del self.norm # remove the original normself.mask_2d = mask_2dself.use_custom_patch = use_custom_patchnum_heads=12depth=12mlp_ratio=4def forward_features(self, x):B = x.shape[0]x = self.patch_embed(x)x = x + self.pos_embed[:, 1:, :]cls_token = self.cls_token + self.pos_embed[:, :1, :]cls_tokens = cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanksx = torch.cat((cls_tokens, x), dim=1)x = self.pos_drop(x) for blk in self.blocks:x = blk(x)if self.global_pool:x = x[:, 1:, :].mean(dim=1) # global pool without cls tokenoutcome = self.fc_norm(x)else:x = self.norm(x)outcome = x[:, 0]return outcomedef random_masking(self, x, mask_ratio):"""Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape # batch, length, dimlen_keep = int(L * (1 - mask_ratio))noise = torch.rand(N, L, device=x.device) # noise in [0, 1]# sort noise for each sampleids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is removeids_restore = torch.argsort(ids_shuffle, dim=1)# keep the first subsetids_keep = ids_shuffle[:, :len_keep]x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))# generate the binary mask: 0 is keep, 1 is removemask = torch.ones([N, L], device=x.device)mask[:, :len_keep] = 0# unshuffle to get the binary maskmask = torch.gather(mask, dim=1, index=ids_restore)return x_masked, mask, ids_restoredef random_masking_2d(self, x, mask_t_prob, mask_f_prob):"""2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob)Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape # batch, length, dimif self.use_custom_patch:# # for AST=101 #64,101F=12 #8,12# # for ESC# T=50# F=12 # for SPC# T=12# F=12else:# ## for AS T=64F=8# ## for ESC#T=32#F=8 ## for SPC# T=8# F=8# mask Tx = x.reshape(N, T, F, D)len_keep_T = int(T * (1 - mask_t_prob))noise = torch.rand(N, T, device=x.device) # noise in [0, 1]# sort noise for each sampleids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is removeids_keep = ids_shuffle[:, :len_keep_T]index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D)#x_masked = torch.gather(x, dim=1, index=index)#x_masked = x_masked.reshape(N,len_keep_T*F,D)x = torch.gather(x, dim=1, index=index) # N, len_keep_T(T'), F, D# mask F#x = x.reshape(N, T, F, D)x = x.permute(0,2,1,3) # N T' F D => N F T' Dlen_keep_F = int(F * (1 - mask_f_prob))noise = torch.rand(N, F, device=x.device) # noise in [0, 1]# sort noise for each sampleids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is removeids_keep = ids_shuffle[:, :len_keep_F]#index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D)index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D)x_masked = torch.gather(x, dim=1, index=index)x_masked = x_masked.permute(0,2,1,3) # N F' T' D => N T' F' D #x_masked = x_masked.reshape(N,len_keep*T,D)x_masked = x_masked.reshape(N,len_keep_F*len_keep_T,D)return x_masked, None, Nonedef forward_features_mask(self, x, mask_t_prob, mask_f_prob):B = x.shape[0] #4,1,1024,128x = self.patch_embed(x) # 4, 512, 768x = x + self.pos_embed[:, 1:, :]if self.random_masking_2d:x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob)else:x, mask, ids_restore = self.random_masking(x, mask_t_prob)cls_token = self.cls_token + self.pos_embed[:, :1, :]cls_tokens = cls_token.expand(B, -1, -1)x = torch.cat((cls_tokens, x), dim=1) x = self.pos_drop(x)# apply Transformer blocksfor blk in self.blocks:x = blk(x)if self.global_pool:x = x[:, 1:, :].mean(dim=1) # global pool without cls tokenoutcome = self.fc_norm(x)else:x = self.norm(x)outcome = x[:, 0]return outcome# overwrite original timmdef forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0):if mask_t_prob > 0.0 or mask_f_prob > 0.0:x = self.forward_features_mask(x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)else:x = self.forward_features(x)x = self.head(x)return xdef vit_small_patch16(**kwargs):model = VisionTransformer(patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) return modeldef vit_base_patch16(**kwargs):model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return modeldef vit_large_patch16(**kwargs):model = VisionTransformer(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return modeldef vit_huge_patch14(**kwargs):model = VisionTransformer(patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return model
1. 整体结构
代码继承自timm
库(PyTorch Image Models)的VisionTransformer
类,在此基础上扩展了功能,主要包括:
- 支持全局平均池化(替代传统的cls token)
- 实现了1D和2D两种掩码机制(用于自监督学习)
- 支持自定义patch划分方式
- 提供了不同规模的ViT模型实例化方法(small/base/large/huge)
2. 核心类:VisionTransformer
该类继承自timm.models.vision_transformer.VisionTransformer
,并重写/扩展了关键方法。
2.1 构造函数 __init__
__init__
构造函数是 VisionTransformer
类的初始化方法,它在父类 timm.models.vision_transformer.VisionTransformer
的基础上扩展了新功能,主要用于配置模型的核心参数和组件。以下是详细说明:
函数定义
def __init__(self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs):super(VisionTransformer, self).__init__(** kwargs) # 调用父类构造函数# 自定义参数初始化self.global_pool = global_poolif self.global_pool:norm_layer = kwargs['norm_layer']embed_dim = kwargs['embed_dim']self.fc_norm = norm_layer(embed_dim) # 全局池化后的归一化层del self.norm # 移除父类的归一化层self.mask_2d = mask_2d # 是否启用2D掩码机制self.use_custom_patch = use_custom_patch # 是否使用自定义patch划分# 以下参数未实际使用(可能是调试残留或示例)num_heads=12depth=12mlp_ratio=4
核心参数解析
1. 输入参数
-
global_pool
(bool,默认False
):- 功能:控制模型最终的特征聚合方式。
- 若为
True
:使用「全局平均池化」替代传统ViT的「cls token」作为特征聚合方式。此时会创建self.fc_norm
作为全局池化后的归一化层。 - 若为
False
:沿用传统ViT的方式,使用cls token聚合特征。
-
mask_2d
(bool,默认True
):- 功能:指定掩码机制的类型。
- 若为
True
:使用random_masking_2d
方法(针对二维结构数据,如时间-频率谱图)。 - 若为
False
:使用random_masking
方法(针对一维序列数据)。
-
use_custom_patch
(bool,默认False
):- 功能:控制是否使用自定义的patch划分维度(影响
random_masking_2d
中的时间/频率维度设置)。 - 若为
True
:random_masking_2d
中会使用自定义的T
(时间维度)和F
(频率维度)值(如代码中注释的T=101, F=12
)。 - 若为
False
:使用默认的T
和F
值(如T=64, F=8
)。
- 功能:控制是否使用自定义的patch划分维度(影响
-
**kwargs
(可变参数):- 功能:接收父类
VisionTransformer
所需的全部参数(如patch_size
、embed_dim
、depth
等),并传递给父类构造函数。 - 父类关键参数包括:
patch_size
(patch大小)、embed_dim
(嵌入维度)、depth
(Transformer层数)、num_heads
(注意力头数)、norm_layer
(归一化层类型)等。
- 功能:接收父类
2. 关键操作解析
-
调用父类构造函数:
super(VisionTransformer, self).__init__(**kwargs)
这行代码初始化了父类
timm.models.vision_transformer.VisionTransformer
的所有基础组件,包括:patch_embed
:将输入图像分块并映射到特征维度的模块。pos_embed
:位置嵌入(包含cls token的位置)。cls_token
:分类标记(用于特征聚合)。blocks
:Transformer块的列表。norm
:父类默认的归一化层(后续可能被删除)。head
:最终的分类头。
-
全局池化相关配置:
if self.global_pool:norm_layer = kwargs['norm_layer'] # 从kwargs中获取归一化层类型(如nn.LayerNorm)embed_dim = kwargs['embed_dim'] # 获取特征嵌入维度self.fc_norm = norm_layer(embed_dim) # 创建全局池化后的归一化层 del self.norm # 移除父类的默认归一化层
- 当启用
global_pool
时,模型不再需要父类的norm
层(因为全局池化后使用fc_norm
归一化),因此显式删除self.norm
。 - 若不启用
global_pool
,虽然代码仍会执行del self.norm
,但后续forward_features
中会发现self.norm
已被删除,可能存在逻辑漏洞(需结合实际使用场景判断)。
- 当启用
-
冗余参数:
代码末尾的num_heads=12
、depth=12
、mlp_ratio=4
未被实际使用,可能是调试残留或示例代码,不影响模型功能。
总结
__init__
构造函数的核心作用是:
- 复用父类
VisionTransformer
的基础组件(patch嵌入、Transformer块等)。 - 新增
global_pool
、mask_2d
、use_custom_patch
三个关键参数,支持:- 两种特征聚合方式(cls token/全局池化)。
- 两种掩码机制(1D/2D)。
- 自定义patch维度(适应不同数据集)。
- 根据
global_pool
动态调整归一化层(删除父类的norm
,新增fc_norm
)。
这些扩展使模型更灵活,尤其适合处理二维结构数据(如音频频谱图)和自监督学习任务。
2.2 特征提取:forward_features
forward_features
方法是 VisionTransformer
类的核心特征提取逻辑,负责将输入数据通过一系列处理(分块嵌入、位置编码、Transformer编码等)转换为高层特征。它是模型从原始输入到最终特征输出的关键流程,以下是详细说明:
函数定义
def forward_features(self, x):B = x.shape[0] # 获取批次大小(batch size)x = self.patch_embed(x) # 1. 将输入分块并嵌入到特征空间x = x + self.pos_embed[:, 1:, :] # 2. 加入位置嵌入(排除cls token的位置)# 3. 处理cls token并与patch特征拼接cls_token = self.cls_token + self.pos_embed[:, :1, :] # cls token加位置嵌入cls_tokens = cls_token.expand(B, -1, -1) # 扩展到整个批次x = torch.cat((cls_tokens, x), dim=1) # 拼接cls token和patch特征序列x = self.pos_drop(x) # 4. 位置dropout(防止过拟合)# 5. 通过所有Transformer块进行特征编码for blk in self.blocks:x = blk(x)# 6. 特征聚合(根据配置选择cls token或全局池化)if self.global_pool:x = x[:, 1:, :].mean(dim=1) # 全局平均池化(排除cls token)outcome = self.fc_norm(x) # 全局池化后的归一化else:x = self.norm(x) # 传统归一化(使用父类的norm层,若未被删除)outcome = x[:, 0] # 取cls token作为最终特征return outcome # 返回提取的高层特征
逐步骤解析
1. 输入与批次大小
x
是输入数据,形状通常为[B, C, H, W]
(B
:批次大小,C
:通道数,H
/W
:高度/宽度),例如图像或频谱图数据。B = x.shape[0]
获取批次大小,用于后续扩展cls_token
到整个批次。
2. Patch嵌入(self.patch_embed(x)
)
- 功能:将输入的二维数据(如图像)分割为固定大小的「patch」(块),并将每个patch映射到高维特征空间。
- 处理过程:
- 例如,输入为
[B, 3, 224, 224]
(3通道图像,224×224分辨率),若patch_size=16
,则会被分割为(224/16)×(224/16)=14×14=196
个patch。 - 每个patch通过线性投影(或卷积)映射到
embed_dim
维度(如768),输出形状为[B, 196, 768]
(B
:批次,196
:patch数量,768
:特征维度)。
- 例如,输入为
- 对应组件:
self.patch_embed
是父类初始化的PatchEmbed
实例,负责分块和嵌入。
3. 位置嵌入(x = x + self.pos_embed[:, 1:, :]
)
- 功能:Transformer本身是无位置信息的,通过加入位置嵌入(positional embedding)让模型感知patch的空间位置。
- 细节:
self.pos_embed
是预定义的位置嵌入参数,形状为[1, L+1, D]
(L
:patch总数,1
:cls token的位置,D
:embed_dim
)。self.pos_embed[:, 1:, :]
表示取所有patch的位置嵌入(排除cls token的位置),与patch特征x
相加,输出形状仍为[B, L, D]
。
4. CLS Token处理与拼接
- CLS Token:是一个可学习的向量(
self.cls_token
),形状为[1, 1, D]
,用于聚合整个序列的特征(类似句子分类中的「[CLS]」标记)。 - 处理过程:
cls_token = self.cls_token + self.pos_embed[:, :1, :]
:给cls token加上对应的位置嵌入(self.pos_embed
的第0位)。cls_tokens = cls_token.expand(B, -1, -1)
:将cls token扩展到整个批次,形状变为[B, 1, D]
。x = torch.cat((cls_tokens, x), dim=1)
:将cls token与patch特征序列拼接,输出形状为[B, L+1, D]
(L+1
:cls token + L个patch)。
5. 位置Dropout(self.pos_drop(x)
)
- 功能:对拼接后的序列(cls token + patch特征 + 位置嵌入)应用dropout,防止模型过度依赖位置信息,增强泛化能力。
- 对应组件:
self.pos_drop
是父类初始化的dropout层(通常为nn.Dropout
)。
6. Transformer块编码(for blk in self.blocks: x = blk(x)
)
- 功能:通过多层Transformer块对序列特征进行深度编码,捕捉patch之间的长距离依赖关系。
- 细节:
self.blocks
是Transformer块的列表,数量由depth
参数指定(如12层)。- 每个Transformer块包含「多头自注意力」和「MLP」两个核心模块,通过残差连接和层归一化增强训练稳定性。
- 输入输出形状保持不变(
[B, L+1, D]
),但特征经过深层语义编码。
7. 特征聚合(最终特征提取)
根据 global_pool
参数选择不同的聚合方式:
-
方式1:全局平均池化(
global_pool=True
)x = x[:, 1:, :].mean(dim=1)
:对所有patch特征(排除cls token)在序列维度(dim=1
)上求平均,得到形状为[B, D]
的全局特征。outcome = self.fc_norm(x)
:通过fc_norm
归一化层(在__init__
中创建)进行归一化,输出最终特征。
-
方式2:CLS Token(
global_pool=False
)x = self.norm(x)
:对整个序列(cls token + patch)进行归一化(使用父类的self.norm
层)。outcome = x[:, 0]
:取cls token(序列的第0位)作为最终特征,形状为[B, D]
。
输入输出总结
- 输入:原始数据
x
,形状[B, C, H, W]
。 - 输出:高层特征
outcome
,形状[B, D]
(D=embed_dim
)。
核心作用
forward_features
实现了ViT的核心特征提取流程:从原始输入→分块嵌入→位置编码→Transformer编码→特征聚合。它将底层的视觉/频谱信息转换为可用于分类、检索等任务的高层语义特征,是模型性能的关键所在。
该方法的灵活性体现在支持两种特征聚合方式(cls token/全局池化),可根据任务需求选择更适合的方案(例如全局池化在某些场景下可能比cls token更鲁棒)。
2.3 1D随机掩码:random_masking
random_masking
方法是实现一维随机掩码机制的核心函数,主要用于自监督学习中(如掩码自编码器MAE)对输入序列进行随机掩盖,通过让模型重建被掩盖的部分来学习更鲁棒的特征表示。以下是详细说明:
函数定义
def random_masking(self, x, mask_ratio):"""Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape # batch, length, dimlen_keep = int(L * (1 - mask_ratio)) # 保留的序列长度noise = torch.rand(N, L, device=x.device) # noise in [0, 1]# sort noise for each sampleids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is removeids_restore = torch.argsort(ids_shuffle, dim=1) # 恢复原始顺序的索引# keep the first subsetids_keep = ids_shuffle[:, :len_keep] # 要保留的元素索引x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) # 保留的特征# generate the binary mask: 0 is keep, 1 is removemask = torch.ones([N, L], device=x.device)mask[:, :len_keep] = 0# unshuffle to get the binary maskmask = torch.gather(mask, dim=1, index=ids_restore) # 恢复掩码到原始序列顺序return x_masked, mask, ids_restore
核心功能
对输入的序列特征进行随机掩码,具体包括:
- 随机选择一部分元素保留,其余元素被掩盖(不参与后续计算)。
- 生成掩码矩阵(标记哪些元素被保留/掩盖)。
- 生成恢复索引(用于后续将掩码后的序列还原到原始顺序)。
该方法通过「基于随机噪声排序」的方式实现随机选择,确保每个样本的掩码是独立且随机的。
参数与输入输出
-
输入:
x
:输入序列特征,形状为[N, L, D]
(N
:批次大小,L
:序列长度,D
:特征维度)。mask_ratio
:掩码比例(被掩盖的元素占比,如0.75
表示掩盖75%的元素)。
-
输出:
x_masked
:掩码后的序列特征,形状为[N, len_keep, D]
(len_keep
为保留的元素数量)。mask
:掩码矩阵,形状为[N, L]
(0
表示保留,1
表示掩盖)。ids_restore
:恢复索引,形状为[N, L]
(用于将掩码后的序列还原到原始顺序)。
逐步骤解析
1. 计算保留长度
N, L, D = x.shape # 解析输入序列的维度
len_keep = int(L * (1 - mask_ratio)) # 计算需要保留的元素数量
- 例如:若序列长度
L=100
,mask_ratio=0.75
,则len_keep=25
(保留25%的元素)。
2. 生成随机噪声并排序
noise = torch.rand(N, L, device=x.device) # 生成[N, L]的随机噪声(取值范围[0,1))
ids_shuffle = torch.argsort(noise, dim=1) # 对噪声按行排序,得到打乱的索引
- 作用:通过对随机噪声排序实现「随机选择」。噪声值越小的位置,在
ids_shuffle
中索引越靠前(优先被保留)。 - 示例:假设
noise
某行为[0.8, 0.2, 0.5]
,则ids_shuffle
为[1, 2, 0]
(对应噪声从小到大的索引)。
3. 生成恢复索引
ids_restore = torch.argsort(ids_shuffle, dim=1) # 对打乱的索引再排序,得到恢复原始顺序的索引
- 作用:记录掩码后的元素如何映射回原始序列位置,用于后续重建任务(如MAE中还原被掩盖的元素)。
- 示例:若
ids_shuffle = [1, 2, 0]
,则ids_restore = [2, 0, 1]
(表示原始索引0在打乱后位于位置2,原始索引1位于位置0, etc.)。
4. 提取保留的特征
ids_keep = ids_shuffle[:, :len_keep] # 取前len_keep个索引(噪声最小的元素)
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
ids_keep.unsqueeze(-1).repeat(1, 1, D)
:将[N, len_keep]
的索引扩展为[N, len_keep, D]
,与输入x
的维度匹配。torch.gather(x, dim=1, index=...)
:沿序列维度(dim=1
)提取保留的元素,得到掩码后的特征x_masked
。
5. 生成掩码矩阵
mask = torch.ones([N, L], device=x.device) # 初始化全为1的掩码(1表示掩盖)
mask[:, :len_keep] = 0 # 前len_keep个位置设为0(0表示保留)
mask = torch.gather(mask, dim=1, index=ids_restore) # 将掩码恢复到原始序列顺序
- 先构造与
ids_shuffle
顺序一致的掩码(前len_keep
个为0),再通过ids_restore
映射回原始序列的顺序,最终mask
中每个位置的值表示该位置是否被掩盖(1=掩盖,0=保留)。
核心特点
- 逐样本随机:每个样本的掩码是独立生成的,避免样本间的掩码相关性。
- 无偏选择:通过均匀分布的随机噪声实现无偏随机选择,确保每个元素被保留的概率相同。
- 可恢复性:通过
ids_restore
实现掩码后序列到原始序列的映射,为后续重建任务提供支持。
应用场景
该方法主要用于自监督预训练,例如:
- 掩码自编码器(MAE):掩盖大部分输入序列,让模型学习从少量保留元素中重建原始序列,从而学习数据的内在结构。
- 对比学习:通过不同掩码策略生成样本的不同视图,用于对比学习任务。
在代码中,当 mask_2d=False
时,forward_features_mask
会调用该方法对序列进行一维掩码处理。
2.4 2D随机掩码:random_masking_2d
random_masking_2d
方法是针对二维结构数据(如音频频谱图、视频帧序列等具有时间-空间或时间-频率维度的数据)设计的掩码机制。与一维掩码(random_masking
)对序列进行无差别随机掩盖不同,它会分别对数据的两个维度(如时间维度 T
和频率维度 F
)进行掩码,更贴合二维数据的内在结构。以下是详细说明:
函数定义
def random_masking_2d(self, x, mask_t_prob, mask_f_prob):"""2D: Spectrogram (masking t and f under mask_t_prob and mask_f_prob)Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape # batch, length, dim(L = T*F,即二维结构展开后的序列长度)if self.use_custom_patch:# 自定义patch划分:根据数据集设置时间(T)和频率(F)维度大小T=101; F=12 # 例如AS(AudioSet)数据集的参数# 其他数据集可选参数:ESC数据集 T=50, F=12;SPC数据集 T=12, F=12else:# 默认patch划分T=64; F=8 # 例如默认AS数据集参数# 其他数据集可选参数:ESC T=32, F=8;SPC T=8, F=8# 第一步:时间维度(T)掩码x = x.reshape(N, T, F, D) # 将序列重塑为二维结构 [N, T, F, D](时间×频率×特征)len_keep_T = int(T * (1 - mask_t_prob)) # 时间维度保留的长度noise = torch.rand(N, T, device=x.device) # 生成时间维度的随机噪声 [N, T]ids_shuffle = torch.argsort(noise, dim=1) # 对噪声排序,得到时间维度的打乱索引ids_keep = ids_shuffle[:, :len_keep_T] # 保留前len_keep_T个时间索引# 构造索引并提取保留的时间维度特征index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D) # 扩展索引维度以匹配xx = torch.gather(x, dim=1, index=index) # 时间维度掩码后:[N, len_keep_T, F, D]# 第二步:频率维度(F)掩码x = x.permute(0, 2, 1, 3) # 转置为 [N, F, len_keep_T, D](方便频率维度处理)len_keep_F = int(F * (1 - mask_f_prob)) # 频率维度保留的长度noise = torch.rand(N, F, device=x.device) # 生成频率维度的随机噪声 [N, F]ids_shuffle = torch.argsort(noise, dim=1) # 对噪声排序,得到频率维度的打乱索引ids_keep = ids_shuffle[:, :len_keep_F] # 保留前len_keep_F个频率索引# 构造索引并提取保留的频率维度特征index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D) # 扩展索引维度x_masked = torch.gather(x, dim=1, index=index) # 频率维度掩码后:[N, len_keep_F, len_keep_T, D]# 恢复维度顺序并重塑为序列x_masked = x_masked.permute(0, 2, 1, 3) # 转置回 [N, len_keep_T, len_keep_F, D]x_masked = x_masked.reshape(N, len_keep_F*len_keep_T, D) # 展开为序列 [N, len_keep_T*len_keep_F, D]return x_masked, None, None # 暂不返回mask和ids_restore(可扩展用于重建)
核心功能
针对二维结构数据(如频谱图的“时间-频率”维度),分两步进行掩码:
- 对时间维度(
T
)按mask_t_prob
比例随机掩盖,保留部分时间片段。 - 对频率维度(
F
)按mask_f_prob
比例随机掩盖,保留部分频率成分。
最终将掩码后的二维结构重新展开为序列,用于后续模型输入。
参数与输入输出
-
输入:
x
:输入序列特征,形状为[N, L, D]
(N
:批次大小,L=T*F
:二维结构展开后的序列长度,D
:特征维度)。mask_t_prob
:时间维度的掩码比例(如0.5
表示掩盖50%的时间片段)。mask_f_prob
:频率维度的掩码比例(如0.3
表示掩盖30%的频率成分)。
-
输出:
x_masked
:掩码后的序列特征,形状为[N, len_keep_T*len_keep_F, D]
(len_keep_T
/len_keep_F
分别为时间/频率维度保留的长度)。None, None
:暂未实现掩码矩阵(mask
)和恢复索引(ids_restore
),可扩展用于后续重建任务。
逐步骤解析
1. 解析输入与维度设置
N, L, D = x.shape # 解析输入维度:N=批次,L=序列长度(T*F),D=特征维度
# 根据use_custom_patch设置时间(T)和频率(F)维度的原始大小
if self.use_custom_patch:T=101; F=12 # 自定义patch划分(如AS数据集)
else:T=64; F=8 # 默认patch划分
- 关键:
L
必须等于T*F
(即输入序列是二维结构[T, F]
展开的结果),否则重塑会出错。 T
和F
的值需与实际数据的时间-频率维度匹配(如频谱图的时间步数和频率 bins 数)。
2. 时间维度(T)掩码
x = x.reshape(N, T, F, D) # 重塑为二维结构:[N, T, F, D](时间×频率×特征)
len_keep_T = int(T * (1 - mask_t_prob)) # 计算时间维度保留的长度
# 生成随机噪声并排序,选择保留的时间索引
noise = torch.rand(N, T, device=x.device) # [N, T] 的随机噪声([0,1))
ids_shuffle = torch.argsort(noise, dim=1) # 对噪声排序,得到时间维度的打乱索引
ids_keep = ids_shuffle[:, :len_keep_T] # 保留前len_keep_T个时间索引(噪声最小的位置)
# 提取保留的时间维度特征
index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D) # 扩展索引为 [N, len_keep_T, F, D]
x = torch.gather(x, dim=1, index=index) # 沿时间维度(dim=1)提取,结果:[N, len_keep_T, F, D]
- 目的:随机保留部分时间片段,掩盖其余时间维度的信息(如在音频频谱图中,掩盖某些时间段的频谱)。
3. 频率维度(F)掩码
x = x.permute(0, 2, 1, 3) # 转置为 [N, F, len_keep_T, D](将频率维度放到dim=1,方便处理)
len_keep_F = int(F * (1 - mask_f_prob)) # 计算频率维度保留的长度
# 生成随机噪声并排序,选择保留的频率索引
noise = torch.rand(N, F, device=x.device) # [N, F] 的随机噪声([0,1))
ids_shuffle = torch.argsort(noise, dim=1) # 对噪声排序,得到频率维度的打乱索引
ids_keep = ids_shuffle[:, :len_keep_F] # 保留前len_keep_F个频率索引(噪声最小的位置)
# 提取保留的频率维度特征
index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D) # 扩展索引为 [N, len_keep_F, len_keep_T, D]
x_masked = torch.gather(x, dim=1, index=index) # 沿频率维度(dim=1)提取,结果:[N, len_keep_F, len_keep_T, D]
- 目的:在时间维度掩码的基础上,进一步随机保留部分频率成分(如在音频频谱图中,掩盖某些频率范围的信息)。
4. 重塑为序列输出
x_masked = x_masked.permute(0, 2, 1, 3) # 转置回 [N, len_keep_T, len_keep_F, D](时间×频率×特征)
x_masked = x_masked.reshape(N, len_keep_F*len_keep_T, D) # 展开为序列:[N, len_keep_T*len_keep_F, D]
- 最终将二维结构重新展开为一维序列,以适应Transformer对序列输入的要求。
核心特点
- 维度感知掩码:区分时间和频率维度分别进行掩码,更贴合二维数据(如频谱图)的物理结构,保留了有意义的局部时空/频域相关性。
- 灵活配置:通过
use_custom_patch
支持不同数据集的维度参数(T
和F
),适配多样化的二维数据。 - 分步掩码:先时间后频率的两步掩码策略,可独立控制两个维度的掩码强度(通过
mask_t_prob
和mask_f_prob
)。
应用场景
主要用于二维结构数据的自监督学习,例如:
- 音频频谱图处理:对音频的梅尔频谱图(时间-频率维度)进行掩码,让模型学习从部分时频信息中重建完整频谱,提升音频分类/检索性能。
- 视频帧序列:对视频的“时间-空间”维度进行掩码(如时间上掩盖部分帧,空间上掩盖部分区域),学习视频的时序和空间特征。
在代码中,当 mask_2d=True
时,forward_features_mask
会调用该方法对二维结构数据进行掩码处理。
2.5 带掩码的特征提取:forward_features_mask
forward_features_mask
方法是该 Vision Transformer 模型中用于带掩码的特征提取的核心逻辑,主要服务于自监督学习场景(如掩码自编码任务)。它在常规特征提取流程(forward_features
)的基础上,加入了掩码操作,通过随机掩盖部分输入特征并让模型学习处理剩余特征,从而增强模型对数据本质结构的理解。以下是详细说明:
函数定义
def forward_features_mask(self, x, mask_t_prob, mask_f_prob):B = x.shape[0] # 获取批次大小(batch size)x = self.patch_embed(x) # 1. 将输入分块并嵌入到特征空间x = x + self.pos_embed[:, 1:, :] # 2. 加入位置嵌入(排除cls token的位置)# 3. 根据配置选择1D或2D掩码机制if self.mask_2d:x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob)else:x, mask, ids_restore = self.random_masking(x, mask_t_prob)# 4. 处理cls token并与掩码后的patch特征拼接cls_token = self.cls_token + self.pos_embed[:, :1, :] # cls token加位置嵌入cls_tokens = cls_token.expand(B, -1, -1) # 扩展到整个批次x = torch.cat((cls_tokens, x), dim=1) # 拼接cls token和掩码后的patch特征序列x = self.pos_drop(x) # 5. 位置dropout(防止过拟合)# 6. 通过所有Transformer块进行特征编码for blk in self.blocks:x = blk(x)# 7. 特征聚合(根据配置选择cls token或全局池化)if self.global_pool:x = x[:, 1:, :].mean(dim=1) # 全局平均池化(排除cls token)outcome = self.fc_norm(x) # 全局池化后的归一化else:x = self.norm(x) # 传统归一化(使用父类的norm层)outcome = x[:, 0] # 取cls token作为最终特征return outcome # 返回提取的高层特征
核心功能
该方法在常规特征提取流程中插入了掩码操作,具体流程为:输入数据→patch嵌入→位置编码→随机掩码→拼接cls token→Transformer编码→特征聚合。其核心目的是让模型在只有部分输入特征(未被掩码的部分)的情况下学习有效特征,从而提升模型的泛化能力和特征表示能力(尤其适用于自监督预训练)。
参数与输入输出
-
输入:
x
:原始输入数据,形状通常为[B, C, H, W]
(如图像、频谱图等)。mask_t_prob
:时间维度(或1D序列)的掩码比例(如0.75
表示掩盖75%的时间/序列元素)。mask_f_prob
:频率维度的掩码比例(仅在mask_2d=True
时有效,如0.5
表示掩盖50%的频率元素)。
-
输出:
outcome
:经过掩码处理和Transformer编码后的高层特征,形状为[B, D]
(D=embed_dim
,特征维度)。
逐步骤解析
1. 输入与批次大小
x
为原始输入数据(如[B, 1, 1024, 128]
的音频频谱图),B = x.shape[0]
获取批次大小,用于后续扩展cls_token
。
2. Patch嵌入(self.patch_embed(x)
)
- 与
forward_features
相同,将输入数据分割为patch并映射到高维特征空间。例如,输入[B, 1, 1024, 128]
经过patch嵌入后可能变为[B, 512, 768]
(512
为patch总数,768
为特征维度)。
3. 位置嵌入(x = x + self.pos_embed[:, 1:, :]
)
- 为patch特征加入位置嵌入,让模型感知patch的空间/时间位置。输出形状仍为
[B, L, D]
(L
为patch总数)。
4. 随机掩码(核心差异点)
if self.mask_2d:x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob)
else:x, mask, ids_restore = self.random_masking(x, mask_t_prob)
- 关键操作:根据
self.mask_2d
选择掩码方式,对patch特征进行随机掩盖:- 若
mask_2d=True
:调用random_masking_2d
,对二维结构(如时间-频率)的patch分别按mask_t_prob
(时间掩码比例)和mask_f_prob
(频率掩码比例)进行掩码。 - 若
mask_2d=False
:调用random_masking
,对一维序列的patch按mask_t_prob
(整体掩码比例)进行掩码。
- 若
- 结果:
x
变为掩码后的特征序列,长度缩短(仅保留未被掩盖的patch),形状为[B, L_keep, D]
(L_keep
为保留的patch数量)。
5. CLS Token处理与拼接
- 与
forward_features
逻辑一致:为cls token加入位置嵌入,扩展到整个批次,并与掩码后的patch特征拼接,输出形状为[B, L_keep+1, D]
(+1
为cls token)。
6. 位置Dropout(self.pos_drop(x)
)
- 对拼接后的序列(cls token + 掩码后的patch特征)应用dropout,增强模型泛化能力。
7. Transformer块编码
- 掩码后的序列通过所有Transformer块(
self.blocks
)进行深度编码,捕捉剩余patch之间的依赖关系。输入输出形状保持[B, L_keep+1, D]
。
8. 特征聚合
- 与
forward_features
相同,根据self.global_pool
选择聚合方式:- 若
global_pool=True
:对掩码后保留的patch特征(排除cls token)做全局平均池化,经fc_norm
归一化后输出。 - 若
global_pool=False
:对序列做归一化后,取cls token作为最终特征。
- 若
与 forward_features
的核心差异
对比项 | forward_features | forward_features_mask |
---|---|---|
掩码操作 | 无(使用全部patch特征) | 有(随机掩盖部分patch特征) |
输入特征完整性 | 完整的patch序列 | 部分patch被掩盖的序列 |
主要用途 | 有监督训练/推理(使用全部信息) | 自监督预训练(模拟信息缺失场景) |
序列长度 | 固定(L+1 ,L 为总patch数) | 可变(L_keep+1 ,L_keep 为保留数) |
应用场景
该方法主要用于自监督预训练任务,例如:
- 掩码自编码器(MAE):通过掩盖大部分输入patch,让模型学习从少量保留信息中重建原始数据,从而学习数据的内在结构。
- 对比学习:通过不同掩码策略生成同一数据的不同视图,让模型学习视图间的一致性,提升特征判别能力。
在模型的 forward
方法中,当 mask_t_prob>0
或 mask_f_prob>0
时,会自动调用该方法进行带掩码的特征提取。
2.6 模型前向传播:forward
forward
方法是模型的入口函数,负责协调整个模型的前向计算流程,根据输入参数决定是否启用掩码机制,并最终输出模型的预测结果。它是连接输入数据与最终输出的核心桥梁,以下是详细说明:
函数定义
def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0):if mask_t_prob > 0.0 or mask_f_prob > 0.0:x = self.forward_features_mask(x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)else:x = self.forward_features(x)x = self.head(x) # 最终分类头处理return x
核心功能
该方法根据掩码概率参数(mask_t_prob
和 mask_f_prob
)决定特征提取的路径:
- 当需要掩码(掩码概率>0)时,调用带掩码的特征提取方法(
forward_features_mask
)。 - 当不需要掩码(掩码概率=0)时,调用常规特征提取方法(
forward_features
)。
最终通过分类头(self.head
)输出预测结果,完成整个前向传播流程。
参数解析
-
输入参数:
x
:原始输入数据,形状通常为[B, C, H, W]
(B
:批次大小,C
:通道数,H/W
:高度/宽度),例如图像、音频频谱图等。v
:未使用的参数(可能是预留接口,用于扩展多模态输入等场景)。mask_t_prob
:时间维度(或1D序列)的掩码比例(默认0.0
,表示不掩码)。当>0
时,启用掩码机制。mask_f_prob
:频率维度的掩码比例(默认0.0
),仅在mask_2d=True
时有效,>0
时启用频率维度掩码。
-
输出:
- 模型的预测结果,形状通常为
[B, num_classes]
(num_classes
为分类任务的类别数)。
- 模型的预测结果,形状通常为
逐步骤解析
1. 决定特征提取路径
if mask_t_prob > 0.0 or mask_f_prob > 0.0:# 启用掩码:调用带掩码的特征提取x = self.forward_features_mask(x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)
else:# 不启用掩码:调用常规特征提取x = self.forward_features(x)
- 逻辑:通过判断掩码概率是否大于0,选择对应的特征提取流程:
- 当需要掩码(如自监督预训练时):使用
forward_features_mask
,对输入数据进行随机掩码后再提取特征。 - 当不需要掩码(如常规训练/推理时):使用
forward_features
,直接对完整输入提取特征。
- 当需要掩码(如自监督预训练时):使用
- 特征提取结果:两种路径均输出形状为
[B, D]
的高层特征(D
为embed_dim
,特征维度)。
2. 分类头处理(self.head(x)
)
x = self.head(x)
- 功能:将高层特征映射到任务输出空间(如分类任务的类别概率)。
- 细节:
self.head
是在父类VisionTransformer
中初始化的分类头,通常为线性层(nn.Linear
),输入维度为embed_dim
,输出维度为任务的类别数(num_classes
)。 - 示例:若
embed_dim=768
,分类任务有1000个类别,则self.head
为nn.Linear(768, 1000)
,输出形状为[B, 1000]
。
3. 返回预测结果
return x
- 输出最终的预测结果(如分类概率分布),供后续计算损失(训练时)或直接使用(推理时)。
关键特性
- 双路径设计:通过掩码概率参数无缝切换「带掩码」和「无掩码」两种模式,兼顾自监督预训练(需要掩码)和常规任务(不需要掩码)的需求。
- 兼容性:保留了父类ViT的基本接口,同时扩展了掩码相关参数,不破坏原有使用习惯。
- 灵活性:支持独立控制时间和频率维度的掩码比例(
mask_t_prob
和mask_f_prob
),适配不同的自监督训练策略。
应用场景
- 自监督预训练:当设置
mask_t_prob>0
或mask_f_prob>0
时,模型进入掩码模式,用于训练模型从部分信息中学习数据结构(如MAE任务)。 - 有监督训练/推理:当
mask_t_prob=0
且mask_f_prob=0
时,模型使用完整输入进行特征提取,用于常规的分类、回归等任务。
总结
forward
方法是模型的总调度器,通过简单的条件判断协调两种特征提取路径,最终输出预测结果。它的设计既满足了自监督学习对掩码机制的需求,又保持了常规任务的兼容性,体现了模型在不同训练阶段和应用场景下的灵活性。
3. 模型实例化函数
定义了不同规模的ViT模型(参数与原始ViT一致):
vit_small_patch16
:小模型,patch大小16×16,嵌入维度384,12层Transformer,6个注意力头。vit_base_patch16
:基础模型,嵌入维度768,12层,12个注意力头。vit_large_patch16
:大模型,嵌入维度1024,24层,16个注意力头。vit_huge_patch14
:超大模型,patch大小14×14,嵌入维度1280,32层,16个注意力头。
4. 核心改进与应用场景
- 灵活的特征聚合:支持cls token或全局平均池化,适应不同任务需求。
- 二维掩码机制:针对时间-频率等二维结构数据(如音频、视频)优化,更符合数据的空间/时间相关性。
- 自监督学习支持:通过随机掩码实现类似MAE(Masked Autoencoder)的预训练任务,提升模型特征学习能力。
该代码可用于图像分类、音频频谱分析等任务,尤其适合需要自监督预训练的场景。
audioMAE
现在介绍audioMAE,先给出完整的代码。
from functools import partial
from json import encoderimport torch
import torch.nn as nn#from timm.models.vision_transformer import PatchEmbed, Block
from timm.models.vision_transformer import Block
from util.pos_embed import get_2d_sincos_pos_embed, get_2d_sincos_pos_embed_flexible, get_1d_sincos_pos_embed_from_grid
from util.misc import concat_all_gather
from util.patch_embed import PatchEmbed_new, PatchEmbed_org
from timm.models.swin_transformer import SwinTransformerBlockclass MaskedAutoencoderViT(nn.Module):""" Masked Autoencoder with VisionTransformer backbone"""def __init__(self, img_size=224, patch_size=16, stride=10, in_chans=3,embed_dim=1024, depth=24, num_heads=16,decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, audio_exp=False, alpha=0.0, temperature=.2, mode=0, contextual_depth=8,use_custom_patch=False, split_pos=False, pos_trainable=False, use_nce=False, beta=4.0, decoder_mode=0,mask_t_prob=0.6, mask_f_prob=0.5, mask_2d=False,epoch=0, no_shift=False,):super().__init__()self.audio_exp=audio_expself.embed_dim = embed_dimself.decoder_embed_dim = decoder_embed_dim# --------------------------------------------------------------------------# MAE encoder specificsif use_custom_patch:print(f'Use custom patch_emb with patch size: {patch_size}, stride: {stride}')self.patch_embed = PatchEmbed_new(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, stride=stride)else:self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim)self.use_custom_patch = use_custom_patchnum_patches = self.patch_embed.num_patchesself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))#self.split_pos = split_pos # not usefulself.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable) # fixed sin-cos embeddingself.encoder_depth = depthself.contextual_depth = contextual_depthself.blocks = nn.ModuleList([Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)for i in range(depth)])self.norm = norm_layer(embed_dim)# --------------------------------------------------------------------------# MAE decoder specificsself.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=pos_trainable) # fixed sin-cos embeddingself.no_shift=no_shiftself.decoder_mode = decoder_modeif self.use_custom_patch: # overlapped patches as in AST. Similar performance yet compute heavywindow_size= (6,6)feat_size = (102,12)else:window_size= (4,4)feat_size = (64,8) if self.decoder_mode == 1:decoder_modules = []for index in range(16):if self.no_shift:shift_size = (0,0)else:if (index % 2) == 0:shift_size = (0,0)else:shift_size = (2,0)#shift_size = tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size])decoder_modules.append(SwinTransformerBlock(dim=decoder_embed_dim,num_heads=16,feat_size=feat_size,window_size=window_size,shift_size=shift_size,mlp_ratio=mlp_ratio,drop=0.0,drop_attn=0.0,drop_path=0.0,extra_norm=False,sequential_attn=False,norm_layer=norm_layer, #nn.LayerNorm,))self.decoder_blocks = nn.ModuleList(decoder_modules) else:# Transfomerself.decoder_blocks = nn.ModuleList([Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)for i in range(decoder_depth)])self.decoder_norm = norm_layer(decoder_embed_dim)self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch# --------------------------------------------------------------------------self.norm_pix_loss = norm_pix_lossself.patch_size=patch_sizeself.stride=stride# audio expsself.alpha = alphaself.T = temperatureself.mode = modeself.use_nce = use_nceself.beta = betaself.log_softmax=nn.LogSoftmax(dim=-1)self.mask_t_prob=mask_t_probself.mask_f_prob=mask_f_probself.mask_2d=mask_2dself.epoch = epochself.initialize_weights()def initialize_weights(self):# initialization# initialize (and freeze) pos_embed by sin-cos embeddingif self.audio_exp:pos_embed = get_2d_sincos_pos_embed_flexible(self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True) else:pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))if self.audio_exp: decoder_pos_embed = get_2d_sincos_pos_embed_flexible(self.decoder_pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True)else:decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))# initialize patch_embed like nn.Linear (instead of nn.Conv2d)w = self.patch_embed.proj.weight.datatorch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)torch.nn.init.normal_(self.cls_token, std=.02)torch.nn.init.normal_(self.mask_token, std=.02)# initialize nn.Linear and nn.LayerNormself.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):# we use xavier_uniform following official JAX ViT:torch.nn.init.xavier_uniform_(m.weight)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)def patchify(self, imgs):"""imgs: (N, 3, H, W)x: (N, L, patch_size**2 *3)L = (H/p)*(W/p)"""p = self.patch_embed.patch_size[0]#assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0if self.audio_exp:if self.use_custom_patch: # overlapped patchh,w = self.patch_embed.patch_hw# todo: fixed h/w patch size and stride size. Make hw custom in the futurex = imgs.unfold(2, self.patch_size, self.stride).unfold(3, self.patch_size, self.stride) # n,1,H,W -> n,1,h,w,p,px = x.reshape(shape=(imgs.shape[0], h*w, p**2 * 1))#x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))#x = torch.einsum('nchpwq->nhwpqc', x)#x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))else:h = imgs.shape[2] // pw = imgs.shape[3] // p#h,w = self.patch_embed.patch_hwx = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))x = torch.einsum('nchpwq->nhwpqc', x)x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))else:h = w = imgs.shape[2] // px = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))x = torch.einsum('nchpwq->nhwpqc', x)x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))return xdef unpatchify(self, x):"""x: (N, L, patch_size**2 *3)specs: (N, 1, H, W)"""p = self.patch_embed.patch_size[0] h = 1024//pw = 128//px = x.reshape(shape=(x.shape[0], h, w, p, p, 1))x = torch.einsum('nhwpqc->nchpwq', x)specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p))return specsdef random_masking(self, x, mask_ratio):"""Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape # batch, length, dimlen_keep = int(L * (1 - mask_ratio))noise = torch.rand(N, L, device=x.device) # noise in [0, 1]# sort noise for each sampleids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is removeids_restore = torch.argsort(ids_shuffle, dim=1)# keep the first subsetids_keep = ids_shuffle[:, :len_keep]x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))# generate the binary mask: 0 is keep, 1 is removemask = torch.ones([N, L], device=x.device)mask[:, :len_keep] = 0# unshuffle to get the binary maskmask = torch.gather(mask, dim=1, index=ids_restore)return x_masked, mask, ids_restoredef random_masking_2d(self, x, mask_t_prob, mask_f_prob):"""2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob)Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape # batch, length, dimif self.use_custom_patch: # overlapped patchT=101F=12else: T=64F=8#x = x.reshape(N, T, F, D)len_keep_t = int(T * (1 - mask_t_prob))len_keep_f = int(F * (1 - mask_f_prob))# noise for mask in timenoise_t = torch.rand(N, T, device=x.device) # noise in [0, 1]# sort noise for each sample aling timeids_shuffle_t = torch.argsort(noise_t, dim=1) # ascend: small is keep, large is removeids_restore_t = torch.argsort(ids_shuffle_t, dim=1) ids_keep_t = ids_shuffle_t[:,:len_keep_t]# noise mask in freqnoise_f = torch.rand(N, F, device=x.device) # noise in [0, 1]ids_shuffle_f = torch.argsort(noise_f, dim=1) # ascend: small is keep, large is removeids_restore_f = torch.argsort(ids_shuffle_f, dim=1) ids_keep_f = ids_shuffle_f[:,:len_keep_f] ## generate the binary mask: 0 is keep, 1 is remove# mask in freqmask_f = torch.ones(N, F, device=x.device)mask_f[:,:len_keep_f] = 0mask_f = torch.gather(mask_f, dim=1, index=ids_restore_f).unsqueeze(1).repeat(1,T,1) # N,T,F# mask in timemask_t = torch.ones(N, T, device=x.device)mask_t[:,:len_keep_t] = 0mask_t = torch.gather(mask_t, dim=1, index=ids_restore_t).unsqueeze(1).repeat(1,F,1).permute(0,2,1) # N,T,Fmask = 1-(1-mask_t)*(1-mask_f) # N, T, F# get masked xid2res=torch.Tensor(list(range(N*T*F))).reshape(N,T,F).to(x.device)id2res = id2res + 999*mask # add a large value for masked elementsid2res2 = torch.argsort(id2res.flatten(start_dim=1))ids_keep=id2res2.flatten(start_dim=1)[:,:len_keep_f*len_keep_t]x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))ids_restore = torch.argsort(id2res2.flatten(start_dim=1))mask = mask.flatten(start_dim=1)return x_masked, mask, ids_restoredef forward_encoder(self, x, mask_ratio, mask_2d=False):# embed patchesx = self.patch_embed(x)# add pos embed w/o cls tokenx = x + self.pos_embed[:, 1:, :]# masking: length -> length * mask_ratioif mask_2d:x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob)else:x, mask, ids_restore = self.random_masking(x, mask_ratio)# append cls tokencls_token = self.cls_token + self.pos_embed[:, :1, :]cls_tokens = cls_token.expand(x.shape[0], -1, -1)x = torch.cat((cls_tokens, x), dim=1)# apply Transformer blocksfor blk in self.blocks:x = blk(x)x = self.norm(x)#emb = self.encoder_emb(x)return x, mask, ids_restore, Nonedef forward_encoder_no_mask(self, x):# embed patchesx = self.patch_embed(x)# add pos embed w/o cls tokenx = x + self.pos_embed[:, 1:, :]# masking: length -> length * mask_ratio#x, mask, ids_restore = self.random_masking(x, mask_ratio)# append cls tokencls_token = self.cls_token + self.pos_embed[:, :1, :]cls_tokens = cls_token.expand(x.shape[0], -1, -1)x = torch.cat((cls_tokens, x), dim=1)# apply Transformer blockscontextual_embs=[]for n, blk in enumerate(self.blocks):x = blk(x)if n > self.contextual_depth:contextual_embs.append(self.norm(x))#x = self.norm(x)contextual_emb = torch.stack(contextual_embs,dim=0).mean(dim=0)return contextual_embdef forward_decoder(self, x, ids_restore):# embed tokensx = self.decoder_embed(x)# append mask tokens to sequencemask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls tokenx_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshufflex = torch.cat([x[:, :1, :], x_], dim=1) # append cls token# add pos embedx = x + self.decoder_pos_embedif self.decoder_mode != 0:B,L,D=x.shapex = x[:,1:,:]if self.use_custom_patch:x = x.reshape(B,101,12,D)x = torch.cat([x,x[:,-1,:].unsqueeze(1)],dim=1) # hackx = x.reshape(B,1224,D)if self.decoder_mode > 3: # mvitx = self.decoder_blocks(x)else:# apply Transformer blocksfor blk in self.decoder_blocks:x = blk(x)x = self.decoder_norm(x)# predictor projectionpred = self.decoder_pred(x)# remove cls tokenif self.decoder_mode != 0:if self.use_custom_patch:pred = pred.reshape(B,102,12,256)pred = pred[:,:101,:,:]pred = pred.reshape(B,1212,256)else:pred = predelse:pred = pred[:, 1:, :]return pred, None, None #emb, emb_pixeldef forward_loss(self, imgs, pred, mask, norm_pix_loss=False):"""imgs: [N, 3, H, W]pred: [N, L, p*p*3]mask: [N, L], 0 is keep, 1 is remove, """target = self.patchify(imgs)if norm_pix_loss:mean = target.mean(dim=-1, keepdim=True)var = target.var(dim=-1, keepdim=True)target = (target - mean) / (var + 1.e-6)**.5loss = (pred - target) ** 2loss = loss.mean(dim=-1) # [N, L], mean loss per patchloss = (loss * mask).sum() / mask.sum() # mean loss on removed patchesreturn loss def forward(self, imgs, mask_ratio=0.8):emb_enc, mask, ids_restore, _ = self.forward_encoder(imgs, mask_ratio, mask_2d=self.mask_2d)pred, _, _ = self.forward_decoder(emb_enc, ids_restore) # [N, L, p*p*3]loss_recon = self.forward_loss(imgs, pred, mask, norm_pix_loss=self.norm_pix_loss)loss_contrastive = torch.FloatTensor([0.0]).cuda()return loss_recon, pred, mask, loss_contrastivedef mae_vit_small_patch16_dec512d8b(**kwargs):model = MaskedAutoencoderViT(patch_size=16, embed_dim=384, depth=12, num_heads=6,decoder_embed_dim=512, decoder_num_heads=16,mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return modeldef mae_vit_base_patch16_dec512d8b(**kwargs):model = MaskedAutoencoderViT(patch_size=16, embed_dim=768, depth=12, num_heads=12,decoder_embed_dim=512, decoder_num_heads=16,mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return modeldef mae_vit_large_patch16_dec512d8b(**kwargs):model = MaskedAutoencoderViT(patch_size=16, embed_dim=1024, depth=24, num_heads=16,decoder_embed_dim=512, decoder_num_heads=16,mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return modeldef mae_vit_huge_patch14_dec512d8b(**kwargs):model = MaskedAutoencoderViT(patch_size=14, embed_dim=1280, depth=32, num_heads=16,decoder_embed_dim=512, decoder_num_heads=16,mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return model# set recommended archs
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks
mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b # decoder: 512 dim, 8 blocks
这段代码实现了一个基于Vision Transformer(ViT)的掩码自编码器(Masked Autoencoder, MAE),核心用于自监督学习——通过随机掩盖输入数据的部分区域,让模型学习从剩余区域重建完整数据,从而提取通用的特征表示。该模型特别适配了音频频谱图(如梅尔频谱)的处理(通过audio_exp
、2D掩码等参数),同时保留了对图像数据的兼容性。
一、整体框架与核心设计
MAE的核心逻辑是“编码-解码”结构:
- 编码器(Encoder):对输入数据分块(Patch),随机掩盖部分Patch后,通过Transformer提取特征。
- 解码器(Decoder):接收编码器输出的“未掩盖Patch特征”,拼接“掩码Token”,通过Transformer(或Swin Transformer)重建被掩盖的Patch。
- 损失计算:仅对被掩盖的Patch计算重建损失,迫使模型学习数据的内在结构。
该代码在标准MAE基础上扩展了音频适配(如重叠Patch、2D掩码)、灵活解码器(支持ViT/Swin解码器)、可配置位置嵌入等功能,下面分模块详细解析。
二、核心类:MaskedAutoencoderViT
2.1 构造函数 __init__
:初始化编码器/解码器组件
__init__
是模型的“骨架搭建”部分,定义了编码器、解码器的核心模块及超参数,参数多达20+,需按功能分组理解:
1. 基础配置参数
参数 | 功能说明 |
---|---|
img_size | 输入数据尺寸(如音频频谱图 1024×128 、图像 224×224 ) |
patch_size | Patch(数据块)的大小(如 16 表示 16×16 的Patch) |
stride | Patch划分的步长(仅自定义Patch时生效,用于生成重叠Patch) |
in_chans | 输入通道数(图像3通道,音频频谱图1通道) |
embed_dim | 编码器特征维度(如ViT-Base为768) |
depth | 编码器Transformer块数量 |
num_heads | 编码器多头注意力头数 |
decoder_embed_dim | 解码器特征维度(通常小于编码器,如512,降低计算量) |
decoder_depth | 解码器Transformer块数量 |
decoder_num_heads | 解码器多头注意力头数 |
mlp_ratio | Transformer块中MLP层的扩张比例(如4表示隐藏层维度是输入的4倍) |
norm_layer | 归一化层类型(默认nn.LayerNorm ) |
norm_pix_loss | 是否对像素进行归一化后计算损失(稳定训练) |
2. 音频适配参数
参数 | 功能说明 |
---|---|
audio_exp | 是否为音频实验(True时处理1通道频谱图,False处理3通道图像) |
use_custom_patch | 是否使用自定义Patch嵌入(PatchEmbed_new ),支持重叠Patch(适配音频) |
mask_t_prob /mask_f_prob | 2D掩码中“时间维度”和“频率维度”的掩码比例(仅音频频谱图生效) |
mask_2d | 是否启用2D掩码(区分时间/频率维度,而非1D无差别掩码) |
3. 训练与结构适配参数
参数 | 功能说明 |
---|---|
pos_trainable | 位置嵌入(Positional Embedding)是否可训练(默认固定为sin-cos嵌入) |
decoder_mode | 解码器类型(0=ViT解码器,1=Swin Transformer解码器,适配二维结构) |
no_shift | Swin解码器中是否禁用“移位窗口”(避免边界效应) |
use_nce | 是否启用对比损失(预留接口,当前未实现) |
alpha /temperature | 对比损失相关超参数(预留) |
4. 核心组件初始化
在 MaskedAutoencoderViT
类的构造函数(__init__
)中,核心组件的初始化围绕编码器(Encoder) 和解码器(Decoder) 两大模块展开,每个模块包含多个关键组件,共同支撑掩码自编码器的“编码-解码-重建”流程。以下是核心组件的详细初始化描述:
一、编码器(Encoder)组件初始化
编码器的作用是:将输入数据分割为Patch、添加位置信息、随机掩码部分Patch后,通过Transformer提取高层特征。其核心组件包括:
1. Patch嵌入层(self.patch_embed
)
- 功能:将输入的二维数据(图像或音频频谱图)分割为固定大小的Patch,并通过线性映射将每个Patch转换为高维特征(维度为
embed_dim
)。 - 初始化逻辑:
if use_custom_patch:self.patch_embed = PatchEmbed_new(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, stride=stride) else:self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim)
PatchEmbed_org
:普通Patch嵌入(无重叠),适用于图像等结构化数据。通过卷积或线性层将[H, W]
的输入按patch_size
分割为(H/patch_size)×(W/patch_size)
个Patch,映射到embed_dim
维度。PatchEmbed_new
:自定义Patch嵌入(支持重叠),适用于音频频谱图等需要保留时序/频域连续性的数据。通过stride
参数控制滑动步长(如patch_size=16
、stride=10
),生成重叠Patch,增强局部相关性。
- 关键参数:
img_size
(输入尺寸)、patch_size
(Patch大小)、stride
(滑动步长,仅自定义模式有效)、in_chans
(输入通道数)、embed_dim
(输出特征维度)。
2. 分类标记(self.cls_token
)
- 功能:一个可学习的向量,用于聚合编码器的全局特征(类似ViT中的
[CLS]
标记),辅助解码器重建全局信息。 - 初始化逻辑:
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
- 形状为
[1, 1, embed_dim]
( batch维度=1,序列长度=1,特征维度=embed_dim)。 - 初始化为全零向量,后续通过训练学习(在
initialize_weights
中用正态分布细化初始化:torch.nn.init.normal_(self.cls_token, std=.02)
)。
- 形状为
3. 编码器位置嵌入(self.pos_embed
)
- 功能:为Patch添加位置信息(Transformer本身无位置感知能力),使模型理解Patch的空间/时序位置关系。
- 初始化逻辑:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable)
- 形状:
[1, num_patches + 1, embed_dim]
,其中num_patches
是总Patch数(由self.patch_embed.num_patches
获取),+1
为cls_token
预留位置。 - 可训练性:由
pos_trainable
控制,True
则位置嵌入可通过训练更新;False
则使用固定的sin-cos嵌入(在initialize_weights
中初始化)。 - 位置嵌入生成:音频数据用
get_2d_sincos_pos_embed_flexible
,图像用get_2d_sincos_pos_embed
,确保与Patch的二维结构匹配。
- 形状:
4. Transformer编码器块(self.blocks
)
- 功能:通过多层Transformer块对Patch特征进行深度编码,捕捉Patch间的长距离依赖关系。
- 初始化逻辑:
self.blocks = nn.ModuleList([Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)for i in range(depth) ])
- 组成:
nn.ModuleList
包含depth
个Block
(来自timm.models.vision_transformer.Block
),每个Block
由“多头自注意力(Multi-Head Attention)”和“MLP”组成,通过残差连接和层归一化增强训练稳定性。 - 关键参数:
embed_dim
(输入特征维度)、num_heads
(注意力头数)、mlp_ratio
(MLP层扩张比例,如4表示隐藏层维度是输入的4倍)、norm_layer
(归一化层类型)。
- 组成:
5. 编码器归一化层(self.norm
)
- 功能:对编码器最后一层的输出进行归一化,稳定特征分布。
- 初始化逻辑:
self.norm = norm_layer(embed_dim)
- 使用
norm_layer
(默认nn.LayerNorm
),输入维度为embed_dim
,在_init_weights
中被初始化为权重=1、偏置=0。
- 使用
二、解码器(Decoder)组件初始化
解码器的作用是:接收编码器输出的“未掩码Patch特征”,拼接“掩码Token”后重建被掩码的Patch。其核心组件包括:
1. 解码器嵌入层(self.decoder_embed
)
- 功能:将编码器输出的特征(
embed_dim
维度)线性投影到解码器的特征维度(decoder_embed_dim
),匹配解码器输入要求。 - 初始化逻辑:
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
- 线性层(
nn.Linear
),输入维度embed_dim
,输出维度decoder_embed_dim
,带偏置。 - 在
_init_weights
中用Xavier均匀初始化权重,偏置初始化为0。
- 线性层(
2. 掩码Token(self.mask_token
)
- 功能:一个可学习的向量,用于填充被掩码的Patch位置(编码器未处理这些Patch,解码器需通过该Token重建)。
- 初始化逻辑:
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
- 形状为
[1, 1, decoder_embed_dim]
(与单个Patch的特征维度匹配)。 - 初始化为全零向量,训练中学习最优值(在
initialize_weights
中用正态分布细化:torch.nn.init.normal_(self.mask_token, std=.02)
)。
- 形状为
3. 解码器位置嵌入(self.decoder_pos_embed
)
- 功能:为解码器的Patch和
cls_token
添加位置信息,确保重建时Patch的位置对齐。 - 初始化逻辑:
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=pos_trainable)
- 形状:
[1, num_patches + 1, decoder_embed_dim]
,与编码器位置嵌入结构一致,但维度适配解码器(decoder_embed_dim
)。 - 初始化方式:同编码器位置嵌入(sin-cos嵌入或可训练),在
initialize_weights
中通过get_2d_sincos_pos_embed
或get_2d_sincos_pos_embed_flexible
生成。
- 形状:
4. 解码器Transformer块(self.decoder_blocks
)
- 功能:对解码器输入(未掩码Patch特征+掩码Token)进行编码,学习重建被掩码Patch的特征。
- 初始化逻辑:支持两种类型的解码器(由
decoder_mode
控制):if self.decoder_mode == 1:# Swin Transformer解码器(适配二维结构)decoder_modules = [SwinTransformerBlock(dim=decoder_embed_dim, num_heads=16, feat_size=feat_size,window_size=window_size, shift_size=shift_size, mlp_ratio=mlp_ratio,norm_layer=norm_layer) for index in range(16)]self.decoder_blocks = nn.ModuleList(decoder_modules) else:# 普通ViT解码器self.decoder_blocks = nn.ModuleList([Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)for _ in range(decoder_depth)])
- 普通ViT解码器(
decoder_mode≠1
):decoder_depth
个标准Block
,参数为decoder_embed_dim
(特征维度)、decoder_num_heads
(注意力头数)等,与编码器结构一致。 - Swin Transformer解码器(
decoder_mode=1
):16个SwinTransformerBlock
(来自timm.models.swin_transformer
),适配二维数据(如音频频谱图的时间-频率维度):window_size
:注意力窗口大小(如(4,4)
),控制局部注意力范围。shift_size
:移位窗口步长(偶数层不移位,奇数层移位2
),避免重复计算并增强全局感受野。feat_size
:特征图尺寸(如(102,12)
),匹配输入Patch的二维结构。
- 普通ViT解码器(
5. 解码器归一化层(self.decoder_norm
)
- 功能:对解码器最后一层的输出进行归一化,稳定重建特征的分布。
- 初始化逻辑:
self.decoder_norm = norm_layer(decoder_embed_dim)
- 使用
norm_layer
,输入维度为decoder_embed_dim
,初始化方式同编码器归一化层。
- 使用
6. 解码器预测层(self.decoder_pred
)
- 功能:将解码器输出的特征(
decoder_embed_dim
维度)线性投影到“单个Patch的原始像素维度”,实现Patch重建。 - 初始化逻辑:
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)
- 输出维度为
patch_size² × in_chans
(单个Patch的像素总数,如16×16×1=256
for 音频单通道)。 - 在
_init_weights
中用Xavier均匀初始化权重,偏置初始化为0。
- 输出维度为
三、核心组件的协同关系
编码器与解码器的组件通过“特征传递”和“位置对齐”协同工作:
- Patch嵌入与位置嵌入:
patch_embed
输出的Patch特征与pos_embed
相加,赋予位置信息。 - 掩码与Token拼接:编码器掩码后保留的Patch特征与
cls_token
拼接,输入blocks
编码;解码器通过mask_token
填充被掩码位置,与编码器输出对齐。 - 特征维度匹配:
decoder_embed
将编码器特征投影到decoder_embed_dim
,确保与解码器块输入维度一致。 - 重建映射:
decoder_pred
将解码器输出映射到Patch像素维度,最终通过unpatchify
还原为原始数据形状。
总结
构造函数通过初始化编码器的“Patch嵌入-位置编码-Transformer编码”组件和解码器的“特征投影-掩码填充-Transformer解码-重建预测”组件,搭建了MAE的完整架构。这些组件的设计充分考虑了音频/图像的多模态适配(重叠Patch、2D掩码)、特征维度匹配(embed_dim
与decoder_embed_dim
)和位置信息一致性(编码器/解码器位置嵌入),为自监督学习中的“掩码-重建”任务提供了基础。
2.2 权重初始化:initialize_weights
initialize_weights
方法是 MaskedAutoencoderViT
模型中权重初始化的核心逻辑,负责为模型所有可学习参数(如位置嵌入、Patch嵌入、Token向量、线性层等)设置合理的初始值。合理的权重初始化是模型训练稳定收敛的前提,尤其对于Transformer这类深度模型,不当的初始化可能导致梯度消失/爆炸或训练停滞。以下是该方法的详细解析:
方法定义与整体作用
def initialize_weights(self):# 初始化(并冻结)位置嵌入为sin-cos嵌入if self.audio_exp:pos_embed = get_2d_sincos_pos_embed_flexible(self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True) else:pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))# 初始化解码器位置嵌入if self.audio_exp: decoder_pos_embed = get_2d_sincos_pos_embed_flexible(self.decoder_pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True)else:decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))# 初始化Patch嵌入层(类似nn.Linear而非nn.Conv2d)w = self.patch_embed.proj.weight.datatorch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))# 初始化cls_token和mask_token(用timm的trunc_normal_等效实现)torch.nn.init.normal_(self.cls_token, std=.02)torch.nn.init.normal_(self.mask_token, std=.02)# 初始化所有nn.Linear和nn.LayerNorm层self.apply(self._init_weights)
核心作用:为模型的关键组件(位置嵌入、Patch嵌入、可学习Token、线性层、归一化层)设置初始权重,确保训练开始时参数分布合理,梯度流动稳定。
逐步骤解析
1. 编码器位置嵌入(self.pos_embed
)的初始化
位置嵌入(Positional Embedding)是Transformer理解序列位置关系的核心,MAE中默认使用固定的正弦余弦嵌入(而非随机初始化),原因是:
- 正弦余弦嵌入具有天然的位置连续性(距离近的位置嵌入相似),更适合捕捉空间/时序关系;
- 固定嵌入可减少训练参数,避免过拟合,尤其在自监督预训练阶段。
初始化逻辑:
if self.audio_exp:# 音频场景:使用灵活的2D正弦余弦嵌入(适配非正方形Patch)pos_embed = get_2d_sincos_pos_embed_flexible(embed_dim=self.pos_embed.shape[-1], # 嵌入维度(如768)grid_size=self.patch_embed.patch_hw, # Patch的二维尺寸(如T=101, F=12)cls_token=True # 预留cls_token的位置(嵌入长度+1))
else:# 图像场景:使用标准2D正弦余弦嵌入(正方形Patch)pos_embed = get_2d_sincos_pos_embed(embed_dim=self.pos_embed.shape[-1],grid_size=int(self.patch_embed.num_patches** .5), # 正方形网格尺寸(如14=224/16)cls_token=True)
# 将生成的numpy数组转换为Tensor,复制到pos_embed参数中
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
get_2d_sincos_pos_embed
:生成正方形网格的位置嵌入,适用于图像(如14×14
的Patch网格)。get_2d_sincos_pos_embed_flexible
:生成非正方形网格的位置嵌入,适用于音频频谱图(如101×12
的时间-频率Patch网格)。- 最终嵌入形状为
[1, num_patches+1, embed_dim]
(+1为cls_token预留位置),通过data.copy_
赋值,若pos_trainable=False
(默认),则后续训练中不更新。
2. 解码器位置嵌入(self.decoder_pos_embed
)的初始化
解码器位置嵌入的作用与编码器类似,用于对齐解码器中Patch的位置关系,确保重建时空间/时序一致性。
初始化逻辑:与编码器位置嵌入完全对称,仅嵌入维度为decoder_embed_dim
(解码器特征维度),代码逻辑相同:
if self.audio_exp: decoder_pos_embed = get_2d_sincos_pos_embed_flexible(...)
else:decoder_pos_embed = get_2d_sincos_pos_embed(...)
self.decoder_pos_embed.data.copy_(...)
3. Patch嵌入层(self.patch_embed.proj
)的初始化
Patch嵌入层(proj
)负责将输入数据的Patch映射到高维特征空间(如patch_embed.proj
是卷积层或线性层),其初始化需确保映射前后的信号方差一致。
初始化逻辑:
w = self.patch_embed.proj.weight.data # 获取卷积/线性层的权重
# 将权重重塑为[输出维度, 输入维度](忽略空间维度),用Xavier均匀初始化
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- Xavier均匀初始化:适用于线性映射(如Patch嵌入本质是线性投影),通过使权重分布的方差与输入/输出维度适配,确保前向传播和反向传播中信号的方差一致,避免梯度消失/爆炸。
- 无论
proj
是卷积层(nn.Conv2d
)还是线性层(nn.Linear
),均视为“输入Patch的像素→输出特征”的线性映射,因此重塑后按线性层方式初始化。
4. 可学习Token(cls_token
和mask_token
)的初始化
cls_token
(分类Token)和mask_token
(掩码Token)是可学习的向量,需初始化为较小的随机值,避免初始时对特征产生过大影响。
初始化逻辑:
# 用标准差0.02的正态分布初始化(等效于timm的trunc_normal_,截断值较大时近似正态分布)
torch.nn.init.normal_(self.cls_token, std=.02)
torch.nn.init.normal_(self.mask_token, std=.02)
cls_token
用于聚合编码器全局特征,mask_token
用于填充被掩码的Patch位置,初始值过小会导致初始影响弱,过大则可能主导特征,0.02的标准差是视觉Transformer中常用的经验值。
5. 线性层与归一化层的初始化(self._init_weights
)
通过self.apply(self._init_weights)
对模型中所有子模块递归应用初始化,重点处理nn.Linear
和nn.LayerNorm
:
def _init_weights(self, m):if isinstance(m, nn.Linear):# 线性层:Xavier均匀初始化权重,偏置设为0torch.nn.init.xavier_uniform_(m.weight)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):# 归一化层:权重设为1,偏置设为0(初始不改变输入分布)nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)
- 线性层(
nn.Linear
):如decoder_embed
(编码器→解码器投影)、decoder_pred
(解码器→Patch重建)等,Xavier初始化确保特征映射的稳定性,偏置初始化为0避免引入额外偏移。 - 归一化层(
nn.LayerNorm
):如self.norm
(编码器输出归一化)、self.decoder_norm
(解码器输出归一化)等,初始化为“权重=1,偏置=0”,确保初始时不改变输入特征的分布(仅在训练中学习调整)。
关键设计原则
- 位置嵌入的固定性:默认使用sin-cos嵌入而非可学习嵌入,利用其天然的位置连续性,尤其适合音频/图像的空间/时序结构。
- 线性映射的一致性:Patch嵌入和线性层均使用Xavier初始化,确保特征在映射过程中方差稳定,避免梯度问题。
- 可学习Token的弱初始化:
cls_token
和mask_token
初始为小随机值,让模型在训练中自主学习最优表示,避免初始主导特征。 - 归一化层的中性初始化:初始不改变输入分布,确保训练初期特征的自然演化。
总结
initialize_weights
方法通过针对性的初始化策略,为MAE的各核心组件(位置嵌入、Patch嵌入、可学习Token、线性层、归一化层)设置了合理的初始权重。这些策略兼顾了Transformer的结构特性(对位置敏感、依赖线性映射)和自监督学习的需求(稳定训练、捕捉数据内在结构),为模型后续的掩码重建任务奠定了基础。
2.3 Patch处理:patchify
与unpatchify
patchify
与unpatchify
是掩码自编码器(MAE)中连接“原始输入数据”与“模型处理的Patch序列”的核心方法:
patchify
:将原始二维数据(图像或音频频谱图)分割为固定大小的Patch,转换为模型可处理的序列格式。unpatchify
:将模型输出的Patch序列还原为原始数据形状,用于计算重建损失(对比原始数据与重建结果)。
这两个方法是MAE“掩码-重建”逻辑的基础,确保输入数据能被模型处理,且输出能被还原为原始格式进行损失计算。以下是详细解析:
一、patchify
:将原始数据分割为Patch序列
功能
将形状为 [N, C, H, W]
的原始数据(N
:批次大小,C
:通道数,H/W
:高度/宽度)分割为 L
个Patch(L = h × w
,h
/w
为Patch的行数/列数),输出形状为 [N, L, p²×C]
的序列(p
为Patch大小,p²×C
为单个Patch的像素/特征维度)。
方法定义
def patchify(self, imgs):"""imgs: (N, C, H, W) # 原始输入数据x: (N, L, patch_size**2 * C) # 输出的Patch序列,L = h*w"""p = self.patch_embed.patch_size[0] # Patch大小(假设正方形Patch)if self.audio_exp: # 处理音频频谱图(1通道)if self.use_custom_patch: # 重叠Patch(自定义划分)h, w = self.patch_embed.patch_hw # Patch的行数和列数(如101×12)# 用unfold生成重叠Patch:在H和W维度按stride滑动窗口x = imgs.unfold(2, self.patch_size, self.stride).unfold(3, self.patch_size, self.stride)# 形状转换:[N, 1, h, w, p, p] → [N, h*w, p²×1]x = x.reshape(shape=(imgs.shape[0], h*w, p**2 * 1))else: # 非重叠Patch(默认划分)h = imgs.shape[2] // p # 行数 = 高度//Patch大小w = imgs.shape[3] // p # 列数 = 宽度//Patch大小# 形状转换:[N, 1, H, W] → [N, 1, h, p, w, p] → 调整维度顺序 → [N, h*w, p²×1]x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))x = torch.einsum('nchpwq->nhwpqc', x) # 调整维度顺序,将Patch内像素放在最后x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))else: # 处理图像(3通道)h = w = imgs.shape[2] // p # 图像通常为正方形,h=w# 形状转换:[N, 3, H, W] → [N, 3, h, p, w, p] → 调整维度顺序 → [N, h*w, p²×3]x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))x = torch.einsum('nchpwq->nhwpqc', x)x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))return x
关键逻辑与分情况解析
patchify
的处理逻辑根据数据类型(音频/图像)和Patch划分方式(重叠/非重叠)有所不同,核心是将原始数据的空间维度(H, W)转换为“Patch数量×单个Patch特征”的序列维度。
1. 音频频谱图处理(self.audio_exp=True
)
音频频谱图通常为单通道(C=1
),且具有时间(H)和频率(W)的二维结构,需保留时序/频域连续性,因此支持重叠Patch。
-
重叠Patch(
self.use_custom_patch=True
):
适用于需要保留局部连续性的场景(如音频时序特征),通过torch.Tensor.unfold
实现滑动窗口分割:unfold(dim, size, step)
:在指定维度(dim=2
为H,dim=3
为W)上,以size=self.patch_size
为窗口大小,step=self.stride
为滑动步长,生成重叠窗口。- 例如:输入频谱图
[N, 1, 1024, 128]
,patch_size=16
,stride=10
,则H维度生成(1024-16)/10 + 1 = 101
个窗口,W维度生成(128-16)/10 + 1 = 12
个窗口,最终得到h=101
、w=12
个Patch,形状为[N, 1, 101, 12, 16, 16]
,reshape后为[N, 101×12, 16²×1] = [N, 1212, 256]
。
-
非重叠Patch(
self.use_custom_patch=False
):
适用于频谱图尺寸能被patch_size
整除的场景,通过reshape
和einsum
分割:- 先将
[N, 1, H, W]
拆分为[N, 1, h, p, w, p]
(h=H//p
,w=W//p
),再用torch.einsum('nchpwq->nhwpqc'
调整维度顺序(将Patch内的p×p
像素放在最后),最后合并为[N, h×w, p²×1]
。
- 先将
2. 图像处理(self.audio_exp=False
)
图像通常为3通道(C=3
),空间结构规则,采用非重叠Patch:
- 输入
[N, 3, H, W]
(如[2, 3, 224, 224]
),patch_size=16
,则h=w=224//16=14
,共14×14=196
个Patch。 - 先拆分为
[N, 3, 14, 16, 14, 16]
,通过einsum
调整维度顺序为[N, 14, 14, 16, 16, 3]
,最终reshape为[N, 196, 16²×3] = [N, 196, 768]
。
核心作用
将原始二维数据转换为模型可处理的序列格式([N, L, D]
,D=p²×C
),作为编码器的输入(后续会被掩码并编码)。
二、unpatchify
:将Patch序列还原为原始数据
功能
将形状为 [N, L, p²×C]
的Patch序列(模型解码器输出的重建结果)还原为 [N, C, H, W]
的原始数据形状,用于与输入数据对比计算重建损失。
方法定义
def unpatchify(self, x):"""x: (N, L, patch_size**2 * C) # Patch序列(重建结果)specs: (N, C, H, W) # 还原的原始数据形状(音频频谱图)"""p = self.patch_embed.patch_size[0] # Patch大小# 音频频谱图的高度和宽度(H=1024,W=128,根据实际场景调整)h = 1024 // p w = 128 // p # 形状转换:[N, L, p²×1] → [N, h, w, p, p, 1]x = x.reshape(shape=(x.shape[0], h, w, p, p, 1))# 调整维度顺序:[N, h, w, p, p, 1] → [N, 1, h, p, w, p]x = torch.einsum('nhwpqc->nchpwq', x)# 合并Patch内的像素,还原为原始尺寸:[N, 1, h*p, w*p] = [N, 1, 1024, 128]specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p))return specs
关键逻辑
unpatchify
是patchify
的逆操作,核心是将“Patch序列”的维度重新映射回“原始数据”的空间维度(H, W):
- reshape拆分:将
[N, L, p²×1]
拆分为[N, h, w, p, p, 1]
,其中h×w=L
(Patch总数),h=H//p
,w=W//p
。 - 维度顺序调整:通过
torch.einsum('nhwpqc->nchpwq'
将通道维度(c
)提前,得到[N, 1, h, p, w, p]
,与patchify
中的中间形状对应。 - 合并还原:将
h
个p
合并为H=h×p
,w
个p
合并为W=w×p
,最终得到[N, 1, H, W]
的原始频谱图形状。
适配场景
代码中unpatchify
主要适配音频频谱图(C=1
),若处理图像(C=3
),只需将通道数改为3,并调整h
和w
的计算(如h=w=224//16=14
),逻辑完全一致。
三、patchify
与unpatchify
的互逆性与核心作用
-
互逆性:
对原始数据imgs
执行unpatchify(patchify(imgs))
,结果应与imgs
完全一致(忽略数值误差),这是确保重建损失计算准确的前提。 -
在MAE中的作用:
patchify
:将输入imgs
转换为Patch序列,作为编码器的输入(后续被掩码、编码)。unpatchify
:将解码器输出的重建Patch序列(pred
)还原为原始数据形状,与imgs
计算MSE损失(forward_loss
中使用),迫使模型学习从部分Patch重建完整数据。
总结
patchify
和unpatchify
是MAE中连接“原始数据”与“模型序列输入/输出”的桥梁:
patchify
通过滑动窗口(重叠)或均匀分割(非重叠)将二维数据转换为Patch序列,适配模型的Transformer结构;unpatchify
通过维度重组将Patch序列还原为原始形状,确保重建损失的准确计算。
两者的设计充分考虑了音频(单通道、重叠Patch)和图像(三通道、非重叠Patch)的差异,体现了模型的多模态适配能力。
2.4 掩码机制:random_masking
与random_masking_2d
在掩码自编码器(MAE)中,掩码机制是核心设计之一,其作用是随机掩盖输入数据的部分Patch,迫使模型从剩余的少量Patch中学习重建完整数据,从而提取更鲁棒的特征表示。代码中实现了两种掩码机制:random_masking
(1D序列掩码)和random_masking_2d
(2D结构化掩码),分别适配不同的数据结构(图像的1D Patch序列 vs 音频频谱图的2D时间-频率结构)。以下是详细解析:
一、random_masking
:1D序列掩码(适用于图像等1D Patch序列)
功能
将输入的Patch序列(形状[N, L, D]
,N
=批次,L
=总Patch数,D
=特征维度)视为一维序列,按指定比例(mask_ratio
)随机掩盖部分Patch,输出掩码后的序列、掩码矩阵及原始顺序恢复索引。
方法定义
def random_masking(self, x, mask_ratio):"""对一维Patch序列进行随机掩码x: [N, L, D] # 输入的Patch序列返回:x_masked: [N, L_keep, D] # 掩码后保留的Patch序列(L_keep = L*(1-mask_ratio))mask: [N, L] # 掩码矩阵(0=保留,1=掩盖)ids_restore: [N, L] # 恢复原始顺序的索引"""N, L, D = x.shape # 解析输入维度len_keep = int(L * (1 - mask_ratio)) # 计算需要保留的Patch数量# 生成随机噪声(用于决定哪些Patch被保留)noise = torch.rand(N, L, device=x.device) # 形状[N, L],值在[0,1)之间# 对噪声排序,得到掩盖/保留的索引ids_shuffle = torch.argsort(noise, dim=1) # 按噪声升序排序,小值对应保留的Patchids_restore = torch.argsort(ids_shuffle, dim=1) # 用于恢复原始顺序的索引(对shuffle索引再排序)# 提取需要保留的Patchids_keep = ids_shuffle[:, :len_keep] # 取前len_keep个索引(噪声最小的Patch)x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) # 按索引提取保留的Patch# 生成掩码矩阵(0=保留,1=掩盖)mask = torch.ones([N, L], device=x.device) # 初始化全为1(默认掩盖)mask[:, :len_keep] = 0 # 前len_keep个位置设为0(保留)mask = torch.gather(mask, dim=1, index=ids_restore) # 按原始顺序恢复掩码矩阵return x_masked, mask, ids_restore
关键步骤解析
- 确定保留数量:根据
mask_ratio
计算保留的Patch数量(len_keep = L*(1-mask_ratio)
),例如mask_ratio=0.75
时保留25%的Patch。 - 随机噪声生成:生成
[N, L]
的随机噪声,每个Patch对应一个噪声值,用于后续排序(噪声越小的Patch越可能被保留)。 - 索引排序:
ids_shuffle
:对噪声按行排序的索引(升序),即“噪声最小的Patch排在最前”。ids_keep
:取ids_shuffle
的前len_keep
个索引,即需要保留的Patch索引。
- 提取保留的Patch:用
torch.gather
按ids_keep
从原始序列x
中提取保留的Patch,得到x_masked
(形状[N, len_keep, D]
)。 - 生成掩码矩阵:
- 初始掩码全为1(表示掩盖),前
len_keep
个位置设为0(表示保留)。 - 用
ids_restore
(ids_shuffle
的逆排序)将掩码矩阵恢复为原始Patch顺序,确保mask[i,j]
对应原始序列中第j
个Patch是否被掩盖。
- 初始掩码全为1(表示掩盖),前
核心特点
- 无差别掩码:将Patch视为一维序列,不区分空间/时序位置,按比例随机掩盖,适用于图像等“全局结构无明显维度差异”的数据。
- 高效性:通过噪声排序实现随机掩码,避免显式采样,计算效率高。
二、random_masking_2d
:2D结构化掩码(适用于音频频谱图等2D结构)
功能
针对具有二维结构的Patch(如音频频谱图的“时间-频率”维度),分别对时间维度(T
)和频率维度(F
)按指定比例(mask_t_prob
和mask_f_prob
)进行掩码,最终输出掩码后的序列、合并的掩码矩阵及原始顺序恢复索引。相比1D掩码,2D掩码更贴合数据的物理结构(如音频的时间连续性和频率相关性)。
方法定义
def random_masking_2d(self, x, mask_t_prob, mask_f_prob):"""对二维结构的Patch(时间T×频率F)进行掩码x: [N, L, D] # 输入的Patch序列(L = T*F)返回:x_masked: [N, L_keep, D] # 掩码后保留的Patch序列(L_keep = T_keep*F_keep)mask: [N, L] # 合并的掩码矩阵(0=保留,1=掩盖)ids_restore: [N, L] # 恢复原始顺序的索引"""N, L, D = x.shape # 解析输入维度# 定义时间(T)和频率(F)维度的大小(根据Patch划分方式)if self.use_custom_patch: # 重叠Patch(音频场景)T = 101 # 时间维度Patch数F = 12 # 频率维度Patch数else: # 非重叠PatchT = 64F = 8# 计算时间和频率维度保留的数量len_keep_t = int(T * (1 - mask_t_prob)) # 时间维度保留数len_keep_f = int(F * (1 - mask_f_prob)) # 频率维度保留数# -------------------------- 时间维度掩码 --------------------------noise_t = torch.rand(N, T, device=x.device) # 时间维度随机噪声[N, T]ids_shuffle_t = torch.argsort(noise_t, dim=1) # 时间维度排序索引(升序)ids_restore_t = torch.argsort(ids_shuffle_t, dim=1) # 时间维度恢复索引ids_keep_t = ids_shuffle_t[:, :len_keep_t] # 时间维度保留的索引# -------------------------- 频率维度掩码 --------------------------noise_f = torch.rand(N, F, device=x.device) # 频率维度随机噪声[N, F]ids_shuffle_f = torch.argsort(noise_f, dim=1) # 频率维度排序索引(升序)ids_restore_f = torch.argsort(ids_shuffle_f, dim=1) # 频率维度恢复索引ids_keep_f = ids_shuffle_f[:, :len_keep_f] # 频率维度保留的索引# -------------------------- 合并掩码矩阵 --------------------------# 生成时间维度掩码(0=保留,1=掩盖)并扩展为[N, T, F]mask_t = torch.ones(N, T, device=x.device)mask_t[:, :len_keep_t] = 0mask_t = torch.gather(mask_t, dim=1, index=ids_restore_t).unsqueeze(1).repeat(1, F, 1).permute(0, 2, 1) # [N, T, F]# 生成频率维度掩码(0=保留,1=掩盖)并扩展为[N, T, F]mask_f = torch.ones(N, F, device=x.device)mask_f[:, :len_keep_f] = 0mask_f = torch.gather(mask_f, dim=1, index=ids_restore_f).unsqueeze(1).repeat(1, T, 1) # [N, T, F]# 合并掩码:时间或频率任一被掩盖,则该Patch被掩盖(1 - (1-时间掩码)*(1-频率掩码))mask = 1 - (1 - mask_t) * (1 - mask_f) # [N, T, F]mask = mask.flatten(start_dim=1) # 展平为[N, L](L=T*F)# -------------------------- 提取保留的Patch --------------------------# 生成原始Patch索引,并给被掩盖的Patch加一个大值(确保排序后被放到最后)id2res = torch.Tensor(list(range(N*T*F))).reshape(N, T, F).to(x.device) # [N, T, F]的原始索引id2res = id2res + 999 * mask # 被掩盖的Patch索引 += 999(值变大)# 展平后排序,取前L_keep个索引(未被掩盖的Patch)id2res2 = torch.argsort(id2res.flatten(start_dim=1)) # 按索引值升序排序ids_keep = id2res2.flatten(start_dim=1)[:, :len_keep_t*len_keep_f] # 保留的Patch索引x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) # 提取保留的Patch# 生成恢复原始顺序的索引ids_restore = torch.argsort(id2res2.flatten(start_dim=1)) # [N, L]return x_masked, mask, ids_restore
关键步骤解析
2D掩码的核心是区分时间和频率维度独立掩码,再合并结果,步骤如下:
-
维度定义:根据
use_custom_patch
确定二维结构的时间维度(T
)和频率维度(F
),例如音频重叠Patch时T=101
、F=12
(总Patch数L=T*F=1212
)。 -
时间维度掩码:
- 生成
[N, T]
的随机噪声noise_t
,排序后得到保留的时间索引ids_keep_t
(数量len_keep_t = T*(1-mask_t_prob)
)。 - 生成时间掩码
mask_t
([N, T]
,0=保留,1=掩盖),扩展为[N, T, F]
(与频率维度对齐)。
- 生成
-
频率维度掩码:
- 生成
[N, F]
的随机噪声noise_f
,排序后得到保留的频率索引ids_keep_f
(数量len_keep_f = F*(1-mask_f_prob)
)。 - 生成频率掩码
mask_f
([N, F]
,0=保留,1=掩盖),扩展为[N, T, F]
(与时间维度对齐)。
- 生成
-
合并掩码矩阵:
- 合并逻辑:
mask = 1 - (1 - mask_t) * (1 - mask_f)
,即“时间或频率任一维度被掩盖,则该Patch被掩盖”(避免仅掩盖单维度导致信息残留)。 - 合并后展平为
[N, L]
(L=T*F
),与原始Patch序列长度一致。
- 合并逻辑:
-
提取保留的Patch:
- 为原始Patch索引(
id2res
)中被掩盖的位置加“大值”(999),确保排序后这些Patch被放在最后。 - 对索引排序后,取前
len_keep_t*len_keep_f
个索引(未被掩盖的Patch),用torch.gather
提取得到x_masked
。
- 为原始Patch索引(
-
生成恢复索引:
ids_restore
为排序索引的逆,用于解码器将掩码后的序列还原到原始顺序,确保重建时Patch位置对齐。
核心特点
- 维度感知:区分时间和频率维度分别掩码,更贴合音频频谱图等二维数据的物理结构(如时间连续性、频率相关性)。
- 灵活控制:通过
mask_t_prob
和mask_f_prob
独立控制两个维度的掩码强度(例如对时间维度掩盖60%,频率维度掩盖50%)。 - 严格掩盖:合并掩码时采用“或”逻辑,确保被掩盖的Patch在至少一个维度上缺失信息,增强重建难度,迫使模型学习更本质的特征。
三、两种掩码机制的对比与适用场景
对比项 | random_masking (1D) | random_masking_2d (2D) |
---|---|---|
数据结构假设 | 一维Patch序列(无维度差异) | 二维结构化Patch(如时间-频率) |
掩码粒度 | 整个序列无差别掩码 | 区分两个维度独立掩码 |
掩码比例控制 | 单一mask_ratio 控制整体比例 | mask_t_prob 和mask_f_prob 分别控制 |
适用场景 | 图像(正方形Patch,全局结构) | 音频频谱图(时间-频率结构) |
核心优势 | 实现简单,计算高效 | 贴合二维数据物理结构,特征学习更有效 |
总结
random_masking
和random_masking_2d
是MAE中实现“信息缺失驱动学习”的核心机制:
- 1D掩码适用于无明显维度差异的序列数据,通过无差别掩码迫使模型学习全局结构;
- 2D掩码适用于二维结构化数据(如音频频谱图),通过维度感知的掩码策略,更精准地模拟真实场景中的信息缺失(如局部时间片段或频率成分丢失)。
两种机制均通过“随机噪声排序-索引提取-掩码矩阵生成”的流程实现高效掩码,并通过恢复索引确保解码器能正确还原序列顺序,为后续重建任务奠定基础。
2.5 编码器前向传播:forward_encoder
与forward_encoder_no_mask
在 MaskedAutoencoderViT
模型中,编码器的前向传播方法分为两种:forward_encoder
(带掩码的前向传播)和 forward_encoder_no_mask
(无掩码的前向传播)。两者均基于 Transformer 架构提取输入数据的特征,但适用场景和处理流程存在显著差异:
forward_encoder
是 MAE 自监督训练的核心,通过随机掩码部分 Patch 迫使模型从有限信息中学习数据结构;forward_encoder_no_mask
用于处理完整输入(无掩码),主要用于预训练后的特征提取或下游任务微调。
一、forward_encoder
:带掩码的编码器前向传播(自监督训练核心)
功能
接收原始输入数据,通过 Patch 嵌入、位置编码、随机掩码部分 Patch 后,经 Transformer 编码器提取特征,输出掩码后的特征序列、掩码矩阵及原始顺序恢复索引,为解码器的重建任务提供输入。
方法定义
def forward_encoder(self, x, mask_ratio, mask_2d=False):# 1. 对输入数据进行Patch嵌入(转换为Patch序列)x = self.patch_embed(x)# 2. 为Patch添加位置嵌入(不含cls_token的位置)x = x + self.pos_embed[:, 1:, :]# 3. 对Patch序列进行随机掩码(1D或2D掩码)if mask_2d:x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob)else:x, mask, ids_restore = self.random_masking(x, mask_ratio)# 4. 拼接cls_token(并添加其位置嵌入)cls_token = self.cls_token + self.pos_embed[:, :1, :] # cls_token的位置嵌入是pos_embed的第0位cls_tokens = cls_token.expand(x.shape[0], -1, -1) # 扩展到批次维度:[1,1,D] → [N,1,D]x = torch.cat((cls_tokens, x), dim=1) # 拼接:[N, L_keep, D] → [N, L_keep+1, D](+1为cls_token)# 5. 通过Transformer编码器块提取特征for blk in self.blocks:x = blk(x)x = self.norm(x) # 编码器输出归一化return x, mask, ids_restore, None # 返回编码特征、掩码、恢复索引
关键步骤解析
1. Patch嵌入(self.patch_embed(x)
)
- 输入
x
为原始数据(图像/音频频谱图,形状[N, C, H, W]
),经patch_embed
转换为 Patch 序列(形状[N, L, embed_dim]
,L
为总 Patch 数,embed_dim
为编码器特征维度)。 - 例如:音频频谱图
[N, 1, 1024, 128]
经重叠 Patch 嵌入后变为[N, 1212, 1024]
(L=101×12=1212
,embed_dim=1024
)。
2. 添加位置嵌入(x = x + self.pos_embed[:, 1:, :]
)
- 位置嵌入
self.pos_embed
形状为[1, L+1, embed_dim]
(+1
为cls_token
预留位置),此处仅取[:, 1:, :]
(即 Patch 对应的位置嵌入),与 Patch 序列x
相加,赋予模型空间/时序位置感知能力。
3. 随机掩码(random_masking
或 random_masking_2d
)
- 根据
mask_2d
选择掩码方式:- 1D 掩码(
mask_2d=False
):调用random_masking
,按mask_ratio
随机掩盖部分 Patch,输出掩码后的序列x
(形状[N, L_keep, embed_dim]
,L_keep
为保留的 Patch 数)。 - 2D 掩码(
mask_2d=True
):调用random_masking_2d
,按mask_t_prob
(时间维度)和mask_f_prob
(频率维度)掩盖 Patch,输出掩码后的序列。
- 1D 掩码(
- 同时返回
mask
(掩码矩阵,[N, L]
,0=保留,1=掩盖)和ids_restore
(恢复原始顺序的索引,[N, L]
),用于解码器还原序列顺序。
4. 拼接 cls_token
(torch.cat((cls_tokens, x), dim=1)
)
cls_token
是可学习的全局特征向量(形状[1, 1, embed_dim]
),先添加其专属位置嵌入(self.pos_embed[:, :1, :]
),再扩展到批次维度([N, 1, embed_dim]
)。- 与掩码后的 Patch 序列
x
拼接,得到[N, L_keep+1, embed_dim]
(+1
为cls_token
),使cls_token
能聚合所有保留 Patch 的全局信息。
5. Transformer 编码(for blk in self.blocks: x = blk(x)
)
- 输入序列经
depth
个 Transformer 块(self.blocks
)处理,每个块包含“多头自注意力”和“MLP”,通过残差连接和层归一化提取高层特征。 - 最终经
self.norm
归一化,输出编码特征x
(形状[N, L_keep+1, embed_dim]
)。
输出与作用
- 输出:编码特征
x
(掩码后 Patch 序列+cls_token
的特征)、掩码矩阵mask
、原始顺序索引ids_restore
。 - 作用:为解码器提供“已编码的保留 Patch 特征”,结合
mask
和ids_restore
,使解码器能定位被掩盖的 Patch 位置并重建。
二、forward_encoder_no_mask
:无掩码的编码器前向传播(特征提取/微调)
功能
处理完整输入数据(不进行掩码),通过 Transformer 编码器提取特征,主要用于预训练后对完整数据的特征提取(如下游分类任务微调),或获取上下文丰富的特征表示。
方法定义
def forward_encoder_no_mask(self, x):# 1. 对输入数据进行Patch嵌入x = self.patch_embed(x)# 2. 为Patch添加位置嵌入(不含cls_token的位置)x = x + self.pos_embed[:, 1:, :]# 3. 拼接cls_token(并添加其位置嵌入)cls_token = self.cls_token + self.pos_embed[:, :1, :]cls_tokens = cls_token.expand(x.shape[0], -1, -1)x = torch.cat((cls_tokens, x), dim=1) # [N, L+1, embed_dim](L为总Patch数,无掩码)# 4. 通过Transformer编码器块,收集深层上下文特征contextual_embs = []for n, blk in enumerate(self.blocks):x = blk(x)# 收集contextual_depth之后的层输出(经归一化)if n > self.contextual_depth:contextual_embs.append(self.norm(x))# 对收集的深层特征取平均,作为最终上下文特征contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0)return contextual_emb
关键步骤解析
1-3. Patch嵌入、位置编码与 cls_token
拼接
- 流程与
forward_encoder
前3步一致,但不进行掩码,因此 Patch 序列为完整序列(L
个 Patch,无删减),拼接cls_token
后形状为[N, L+1, embed_dim]
。
4. 收集深层上下文特征(contextual_embs
)
- 与
forward_encoder
直接通过所有 Transformer 块后输出不同,forward_encoder_no_mask
会收集多个深层 Transformer 块的输出:self.contextual_depth
为阈值(如8),仅收集索引n > contextual_depth
的 Transformer 块输出(即更深层的特征)。- 每个深层输出经
self.norm
归一化后存入contextual_embs
,最终通过torch.stack
拼接并取平均(mean(dim=0)
),得到contextual_emb
(形状[N, L+1, embed_dim]
)。
设计动机
深层 Transformer 块的输出通常包含更抽象、更全局的特征(浅层特征偏向局部细节),通过收集多个深层特征并取平均,可获得更鲁棒的上下文特征,提升下游任务(如分类)的性能。
输出与作用
- 输出:
contextual_emb
(深层 Transformer 特征的平均值,[N, L+1, embed_dim]
)。 - 作用:为下游任务提供高质量特征(如取
contextual_emb[:, 0, :]
作为全局特征用于分类),或用于模型微调阶段的特征学习。
三、两种编码器前向传播的对比
对比项 | forward_encoder (带掩码) | forward_encoder_no_mask (无掩码) |
---|---|---|
核心操作 | 包含随机掩码步骤(掩盖部分Patch) | 无掩码,使用完整Patch序列 |
输入处理 | 输出掩码后的Patch序列特征 | 输出完整Patch序列的深层特征平均值 |
关键输出 | 编码特征、掩码矩阵、恢复索引 | 上下文特征(深层特征平均) |
适用场景 | MAE自监督预训练(配合解码器重建) | 预训练后特征提取、下游任务微调 |
设计目标 | 迫使模型从有限信息学习数据结构 | 提取鲁棒的完整特征用于下游任务 |
序列长度 | 短(L_keep + 1 ,仅保留部分Patch) | 长(L + 1 ,包含所有Patch) |
总结
forward_encoder
和 forward_encoder_no_mask
是编码器针对不同阶段设计的前向传播方法:
forward_encoder
是 MAE 自监督训练的核心,通过掩码机制创造“信息缺失”场景,驱动模型学习数据的内在结构,为解码器的重建任务提供输入;forward_encoder_no_mask
专注于完整输入的特征提取,通过聚合深层 Transformer 特征,为下游任务提供高质量的上下文特征,实现预训练模型的迁移应用。
两者共同构成了 MAE“预训练-微调”全流程的编码器逻辑,兼顾了自监督学习的特征学习能力和下游任务的实用性。
2.6 解码器前向传播:forward_decoder
forward_decoder
是掩码自编码器(MAE)中负责“重建被掩盖Patch”的核心方法。它接收编码器输出的“未掩盖Patch特征”,结合掩码Token和位置信息,通过解码器网络(Transformer或Swin Transformer)重建完整的Patch序列,最终输出可与原始数据对比的重建结果。以下是详细解析:
功能与输入输出
- 功能:从编码器输出的“掩码后特征”中恢复完整的Patch序列,实现被掩盖区域的重建。
- 输入:
x
:编码器输出的特征(形状[N, L_keep+1, embed_dim]
,L_keep
为未掩盖的Patch数,+1
为cls_token
)。ids_restore
:用于将掩码后的序列恢复为原始Patch顺序的索引(形状[N, L]
,L
为总Patch数)。
- 输出:
pred
:重建的Patch序列(形状[N, L, p²×C]
,p
为Patch大小,C
为通道数),可通过unpatchify
还原为原始数据形状。
方法定义与详细步骤
def forward_decoder(self, x, ids_restore):# 1. 将编码器特征投影到解码器维度x = self.decoder_embed(x) # [N, L_keep+1, embed_dim] → [N, L_keep+1, decoder_embed_dim]# 2. 生成并拼接掩码Token(填充被掩盖的Patch位置)# 计算需要填充的掩码Token数量:总Patch数 + 1(cls_token) - 编码器输出长度mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)# 拼接未掩盖的Patch特征(不含cls_token)和掩码Token → [N, L_keep + (L - L_keep), decoder_embed_dim] = [N, L, decoder_embed_dim]x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # 去掉cls_token,拼接掩码Token# 按ids_restore恢复原始Patch顺序(确保位置对齐)x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))# 拼接回cls_token → [N, L+1, decoder_embed_dim](L+1 = 总Patch数 + cls_token)x = torch.cat([x[:, :1, :], x_], dim=1)# 3. 添加解码器位置嵌入(确保位置信息对齐)x = x + self.decoder_pos_embed # [N, L+1, decoder_embed_dim]# 4. 适配Swin解码器的二维结构(若启用)if self.decoder_mode != 0:B, L, D = x.shapex = x[:, 1:, :] # 移除cls_token,仅保留Patch特征 → [N, L, decoder_embed_dim]if self.use_custom_patch: # 重叠Patch(音频场景)# 调整形状为二维结构(时间T×频率F),并补全尺寸(避免维度不匹配)x = x.reshape(B, 101, 12, D)x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1) # 补全一行,适配Swin窗口x = x.reshape(B, 1224, D) # 102×12=1224# 5. 通过解码器块提取特征(Transformer或Swin Transformer)if self.decoder_mode > 3: # 预留的mvit解码器模式x = self.decoder_blocks(x)else:for blk in self.decoder_blocks:x = blk(x) # 逐块处理# 6. 解码器输出归一化x = self.decoder_norm(x) # [N, L+1 (或L), decoder_embed_dim]# 7. 预测Patch的像素值(投影到原始Patch维度)pred = self.decoder_pred(x) # [N, L+1 (或L), p²×C]# 8. 移除cls_token,仅保留Patch的重建结果if self.decoder_mode != 0:if self.use_custom_patch:# 还原重叠Patch的形状,去掉补全的部分pred = pred.reshape(B, 102, 12, 256)pred = pred[:, :101, :, :] # 去掉补全的一行pred = pred.reshape(B, 1212, 256) # 101×12=1212(原始总Patch数)else:pred = pred[:, 1:, :] # 普通解码器直接移除cls_tokenreturn pred, None, None # 返回重建的Patch序列
关键步骤解析
1. 编码器特征投影(self.decoder_embed(x)
)
- 编码器输出特征的维度为
embed_dim
(如1024),而解码器的特征维度为decoder_embed_dim
(如512)。这一步通过线性层(self.decoder_embed
)将维度转换,确保与解码器后续层的输入维度匹配。 - 形状变化:
[N, L_keep+1, embed_dim] → [N, L_keep+1, decoder_embed_dim]
。
2. 掩码Token拼接与顺序恢复
这是解码器重建的核心步骤,目的是填充被掩盖的Patch位置并恢复原始顺序:
- 掩码Token生成:
mask_tokens
是可学习的向量(self.mask_token
),数量为ids_restore.shape[1] + 1 - x.shape[1]
(即总Patch数L
+ 1(cls_token) - 编码器输出长度L_keep+1
=L - L_keep
,恰好等于被掩盖的Patch数)。 - 拼接未掩盖特征与掩码Token:
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
去掉编码器输出中的cls_token
,拼接掩码Token,得到长度为L_keep + (L - L_keep) = L
的序列(所有Patch位置均被填充,未掩盖位置为编码器特征,掩盖位置为掩码Token)。 - 恢复原始顺序:
torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
利用ids_restore
将拼接后的序列重新排列为原始Patch顺序(确保重建的Patch与输入数据的位置一一对应)。 - 拼接cls_token:最后将
cls_token
拼接回序列,得到[N, L+1, decoder_embed_dim]
(与解码器位置嵌入的形状匹配)。
3. 添加解码器位置嵌入(x = x + self.decoder_pos_embed
)
- 解码器位置嵌入(
self.decoder_pos_embed
)形状为[1, L+1, decoder_embed_dim]
,与编码器位置嵌入结构一致,但维度适配解码器。添加位置嵌入后,解码器能正确理解Patch的空间/时序位置关系,确保重建的空间一致性。
4. 适配Swin解码器的二维结构(if self.decoder_mode != 0
)
当使用Swin Transformer作为解码器(decoder_mode=1
)时,需将序列特征调整为二维结构(适配Swin的窗口注意力机制):
- 移除cls_token:Swin解码器专注于Patch的空间重建,无需
cls_token
,因此去掉x[:, 1:, :]
。 - 形状调整:对于重叠Patch(音频场景),将序列
[N, L, D]
重塑为[N, T, F, D]
(T
=时间维度Patch数,F
=频率维度Patch数,如101×12),并通过torch.cat
补全一行(适配Swin的窗口大小,避免维度不匹配),最后重塑回序列格式输入解码器块。
5. 解码器块特征提取(for blk in self.decoder_blocks: x = blk(x)
)
- 根据
decoder_mode
选择解码器类型:- 普通Transformer解码器(
decoder_mode=0
):通过decoder_depth
个标准Transformer块(Block
)处理,每个块含多头自注意力和MLP,捕捉Patch间的长距离依赖,优化重建特征。 - Swin Transformer解码器(
decoder_mode=1
):通过16个SwinTransformerBlock
处理,利用窗口注意力和移位窗口机制,更高效地捕捉二维结构(如音频时间-频率)的局部相关性。
- 普通Transformer解码器(
6. 归一化与像素预测(self.decoder_norm(x)
和 self.decoder_pred(x)
)
- 归一化:
self.decoder_norm
对解码器块的输出进行层归一化,稳定特征分布。 - 像素预测:
self.decoder_pred
是线性层,将解码器特征(decoder_embed_dim
)投影到“单个Patch的像素维度”(p²×C
,如16×16×1=256 for 音频单通道),得到每个Patch的重建像素值。
7. 移除cls_token(最终输出处理)
- 重建目标是原始数据的Patch序列,因此需移除
cls_token
:- 普通解码器直接取
pred[:, 1:, :]
。 - Swin解码器(重叠Patch)需先还原二维形状,去掉之前补全的行,再重塑为原始总Patch数
L
,确保输出形状与patchify
后的原始Patch序列一致。
- 普通解码器直接取
核心作用与设计逻辑
解码器的核心目标是从“部分观察”(未掩盖的Patch)推断“全局完整信息”(被掩盖的Patch),其设计逻辑围绕以下原则:
- 信息补全:通过掩码Token填充被掩盖位置,为解码器提供完整的序列结构。
- 位置对齐:解码器位置嵌入确保重建的Patch与原始数据的空间/时序位置一一对应。
- 结构适配:支持Transformer(全局注意力)和Swin Transformer(局部窗口注意力),分别适配图像的全局结构和音频的二维时间-频率结构。
- 维度匹配:通过投影层确保编码器特征与解码器维度兼容,最终预测层输出与原始Patch的像素维度一致,为损失计算(
forward_loss
)提供可对比的重建结果。
总结
forward_decoder
是MAE实现“重建任务”的核心模块,通过“特征投影-掩码填充-顺序恢复-位置编码-深度特征提取-像素预测”的流程,将编码器输出的有限信息转化为完整的Patch序列重建结果。其设计兼顾了多模态数据(图像/音频)的结构差异,通过灵活的解码器类型(Transformer/Swin)和形状调整,确保重建精度和效率,最终辅助模型学习数据的内在结构特征。
2.7 损失计算:forward_loss
forward_loss
是掩码自编码器(MAE)中计算“重建损失”的核心方法,用于衡量模型重建的被掩码Patch与原始数据中对应Patch的差异。该损失是MAE自监督训练的优化目标,直接驱动模型学习从部分信息中重建完整数据的能力。以下是详细解析:
功能与输入输出
- 功能:仅针对被掩码的Patch计算重建误差(忽略未掩码的Patch),确保模型专注于学习从有限信息中恢复缺失内容。
- 输入:
imgs
:原始输入数据(形状[N, C, H, W]
,N
=批次,C
=通道,H/W
=高度/宽度)。pred
:解码器输出的重建Patch序列(形状[N, L, p²×C]
,L
=总Patch数,p
=Patch大小)。mask
:掩码矩阵(形状[N, L]
,0=未掩码,1=被掩码)。
- 输出:
- 被掩码Patch的平均重建损失(标量,通常为MSE损失)。
方法定义与详细步骤
def forward_loss(self, imgs, pred, mask):"""imgs: [N, C, H, W] # 原始输入数据pred: [N, L, p²×C] # 解码器重建的Patch序列mask: [N, L] # 掩码矩阵(1表示被掩码的Patch)"""# 1. 将原始图像分割为Patch序列(与pred的形状匹配)target = self.patchify(imgs) # [N, L, p²×C]# 2. 计算重建Patch与原始Patch的MSE损失(逐元素)loss = (pred - target) **2 # [N, L, p²×C],每个元素的平方误差loss = loss.mean(dim=-1) # 对单个Patch内的所有像素求平均 → [N, L]# 3. 仅计算被掩码Patch的损失(mask=1的位置),并求平均loss = (loss * mask).sum() / mask.sum() # 被掩码Patch的平均损失return loss
#####** 关键步骤解析 **
1. 原始数据Patch化(target = self.patchify(imgs)
- 调用
patchify
方法将原始输入imgs
分割为与pred
形状完全一致的Patch序列target
(形状[N, L, p²×C]
)。 - 这一步确保“重建结果”与“原始数据”在Patch级别对齐,可直接计算逐Patch的误差。
- 例如:原始音频频谱图
[N, 1, 1024, 128]
经patchify
后变为[N, 1212, 256]
(L=101×12=1212
,p=16
,p²×C=16²×1=256
),与pred
形状完全匹配。
2. 逐元素MSE计算(`loss = (pred - target)
-** 逐元素平方误差 :计算重建Patch pred
与原始Patch target
的差值平方,得到形状 [N, L, p²×C]
的误差矩阵(每个元素对应单个像素的重建误差)。
- 单个Patch内平均 **:通过 loss.mean(dim=-1)
对每个Patch的所有像素(p²×C
维度)求平均,得到 [N, L]
的误差矩阵(每个元素对应一个Patch的平均重建误差)。
#####** 3. 掩码Patch的损失聚合(loss = (loss * mask).sum() / mask.sum()
)- 这是MAE损失计算的核心设计: 仅关注被掩码的Patch **,忽略未掩码的Patch(因为未掩码的Patch已被编码器直接观察到,重建它们无法体现模型的推断能力)。
loss * mask
:通过掩码矩阵mask
(1=被掩码,0=未掩码)过滤误差,仅保留被掩码Patch的误差(未掩码Patch的误差被置为0)。sum()
:对所有被掩码Patch的误差求和。/ mask.sum()
:除以被掩码Patch的总数量(mask.sum()
得到批次中被掩码的Patch总数),得到被掩码Patch的平均重建损失。
###** 核心设计原则 1. 聚焦掩码区域 :仅计算被掩码Patch的损失,迫使模型学习从“未观察信息”(未掩码Patch)推断“缺失信息”(被掩码Patch),而非简单复制已观察内容。
2. 平均化处理 **:
- 先对单个Patch内的所有像素求平均(
mean(dim=-1)
),避免单个Patch因像素数量多(如p=16
时有256个像素)而主导损失。 - 再对所有被掩码Patch求平均(
sum() / mask.sum()
),确保不同批次中掩码数量不同时损失仍可比较。
3.** 与Patch划分一致 **:依赖patchify
方法确保pred
与target
的形状严格匹配,保证误差计算的准确性。
###** 与其他损失计算的对比 - 全Patch损失 :若计算所有Patch(包括未掩码)的损失,模型可能“偷懒”复制未掩码Patch的信息,而不学习真正的推断能力。
- 逐像素损失 **:若直接对原始图像和重建图像计算逐像素损失(不通过Patch),会包含未掩码区域的误差,违背MAE“掩码重建”的设计初衷。
- MAE的
forward_loss
通过聚焦掩码区域,精准对齐了自监督学习的目标:** 从部分观察中学习数据的全局结构 **。
###** 总结**
forward_loss
是MAE训练的“指挥棒”,通过以下逻辑驱动模型优化:
- 将原始数据和重建结果转换为Patch序列,确保误差计算在相同粒度上进行;
- 仅对被掩码的Patch计算平均MSE损失,引导模型专注于学习从有限信息中恢复缺失内容的能力;
- 损失的平均化处理保证了不同批次、不同掩码数量下的损失可比较性。
这一设计使MAE能高效学习数据的内在特征,为下游任务(如分类、分割)提供高质量的预训练模型。
2.8 模型总前向传播:forward
forward
是 MaskedAutoencoderViT
模型的总前向传播方法,它整合了编码器(带掩码)、解码器、损失计算等核心组件,实现了从原始输入数据到重建损失(或重建结果)的完整流程。该方法是模型训练和推理的入口,直接体现了MAE“掩码-编码-解码-重建-损失”的核心逻辑。以下是详细解析:
功能与输入输出
- 功能:协调模型各模块(编码器、解码器、损失计算),完成“输入数据→掩码编码→解码重建→损失计算”的全流程。
- 输入:
imgs
:原始输入数据(图像或音频频谱图,形状[N, C, H, W]
,N
=批次,C
=通道,H/W
=高度/宽度)。mask_ratio
:掩码比例(如0.75,表示掩盖75%的Patch)。mask_2d
:是否使用2D掩码(True
用于音频等二维结构数据,False
用于图像)。
- 输出:
- 训练阶段:返回被掩码Patch的重建损失(标量)。
- 推理阶段:可额外返回重建的原始数据(通过
unpatchify
还原),用于可视化或评估。
方法定义与详细步骤
def forward(self, imgs, mask_ratio=0.75, mask_2d=False):# 1. 编码器前向传播(带掩码):得到编码特征、掩码矩阵、原始顺序索引latent, mask, ids_restore, _ = self.forward_encoder(imgs, mask_ratio, mask_2d=mask_2d)# 2. 解码器前向传播:根据编码特征和索引重建Patch序列pred, _, _ = self.forward_decoder(latent, ids_restore) # pred: [N, L, p²×C]# 3. 计算重建损失(仅针对被掩码的Patch)loss = self.forward_loss(imgs, pred, mask)# 4. (可选)还原重建的原始数据形状(用于可视化或推理)pred_spec = self.unpatchify(pred) # [N, C, H, W],与原始输入形状一致return loss, pred_spec, mask
关键步骤解析
1. 编码器带掩码前向传播(self.forward_encoder
)
- 输入:原始数据
imgs
、掩码比例mask_ratio
、掩码类型mask_2d
。 - 处理逻辑:
- 将
imgs
转换为Patch序列并添加位置嵌入。 - 根据
mask_2d
选择1D或2D掩码机制,随机掩盖部分Patch。 - 通过Transformer编码器提取保留Patch的特征,拼接
cls_token
并归一化。
- 将
- 输出:
latent
:编码器输出的特征(形状[N, L_keep+1, embed_dim]
,L_keep
为保留的Patch数,+1
为cls_token
)。mask
:掩码矩阵([N, L]
,1=被掩码,0=未掩码)。ids_restore
:用于将掩码后序列恢复为原始顺序的索引([N, L]
)。
2. 解码器前向传播(self.forward_decoder
)
- 输入:编码器特征
latent
、恢复索引ids_restore
。 - 处理逻辑:
- 将
latent
投影到解码器维度(decoder_embed_dim
)。 - 生成掩码Token(
mask_token
),填充被掩码的Patch位置,并通过ids_restore
恢复原始顺序。 - 添加解码器位置嵌入,经解码器(Transformer或Swin Transformer)提取特征后,预测每个Patch的像素值。
- 将
- 输出:
pred
:重建的Patch序列(形状[N, L, p²×C]
,与patchify(imgs)
输出的原始Patch序列形状一致)。
3. 重建损失计算(self.forward_loss
)
- 输入:原始数据
imgs
、重建Patch序列pred
、掩码矩阵mask
。 - 处理逻辑:
- 将
imgs
转换为原始Patch序列(target
),与pred
对齐。 - 计算
pred
与target
的MSE误差,仅对被掩码的Patch(mask=1
)求平均,得到最终损失。
- 将
- 输出:被掩码Patch的平均重建损失(标量),作为模型训练的优化目标。
4. 重建结果还原(self.unpatchify(pred)
)
- 作用:将重建的Patch序列
pred
还原为与原始输入imgs
形状一致的数据([N, C, H, W]
),用于可视化重建效果(如对比原始图像与模型重建的图像)或推理阶段的结果输出。
数据流向与维度变化
为更清晰展示流程,以音频频谱图(N=2
,C=1
,H=1024
,W=128
,patch_size=16
,mask_ratio=0.75
,mask_2d=True
)为例:
- 输入数据:
imgs
形状[2, 1, 1024, 128]
。 - 编码器处理:
patch_embed
转换为Patch序列[2, 1212, 1024]
(L=101×12=1212
,embed_dim=1024
)。- 2D掩码后保留
1212×(1-0.75)=303
个Patch,latent
形状[2, 303+1, 1024]
(+1
为cls_token
)。 - 输出
mask
([2, 1212]
)和ids_restore
([2, 1212]
)。
- 解码器处理:
decoder_embed
投影为[2, 304, 512]
(decoder_embed_dim=512
)。- 填充
1212-303=909
个掩码Token,恢复顺序后序列长度为1212,经解码器处理后pred
形状[2, 1212, 256]
(p²×C=16²×1=256
)。
- 损失计算:
patchify(imgs)
得到原始Patch序列[2, 1212, 256]
,与pred
计算MSE,仅保留mask=1
的位置,输出损失标量。
- 重建结果:
unpatchify(pred)
还原为[2, 1, 1024, 128]
,与原始输入形状一致。
核心作用与设计逻辑
forward
方法是模型的“中枢”,其设计体现了MAE的自监督学习逻辑:
- 端到端流程:从原始数据输入到损失输出,整合了“Patch化-掩码-编码-解码-重建-损失”的全链路,确保训练过程可直接优化。
- 模块化协同:通过调用
forward_encoder
、forward_decoder
、forward_loss
等模块化方法,实现功能解耦,便于维护和扩展(如替换编码器/解码器结构)。 - 兼顾训练与推理:训练时返回损失用于优化,推理时返回重建结果用于可视化或评估,满足不同阶段的需求。
- 多模态适配:通过
mask_2d
参数切换1D/2D掩码,适配图像(1D序列)和音频频谱图(2D结构)等不同类型数据,体现模型的通用性。
总结
forward
方法是 MaskedAutoencoderViT
模型的核心入口,它通过协调编码器(带掩码)、解码器和损失计算模块,实现了从原始输入到重建损失的完整流程。其设计既保证了MAE“掩码重建”的自监督学习逻辑,又通过模块化和参数控制(如 mask_2d
)适配了多模态数据,为模型的训练和应用提供了统一接口。
三、模型实例化函数
代码提供了4种规模的MAE模型实例化函数,参数与标准ViT对应,仅解码器统一为“512维+8层”(平衡性能与计算量):
函数名 | 编码器配置(patch_size=16) | 解码器配置(固定) |
---|---|---|
mae_vit_small_patch16 | embed_dim=384, depth=12, num_heads=6 | decoder_embed_dim=512, depth=8 |
mae_vit_base_patch16 | embed_dim=768, depth=12, num_heads=12 | decoder_embed_dim=512, depth=8 |
mae_vit_large_patch16 | embed_dim=1024, depth=24, num_heads=16 | decoder_embed_dim=512, depth=8 |
mae_vit_huge_patch14 | embed_dim=1280, depth=32, num_heads=16 | decoder_embed_dim=512, depth=8 |
四、核心应用场景
该模型主要用于音频频谱图的自监督预训练(通过audio_exp=True
、2D掩码、重叠Patch适配),预训练后可微调用于:
- 音频分类(如环境声分类ESC-50、语音情感识别)。
- 音频检索、异常音频检测等任务。
- 也可通过
audio_exp=False
适配图像任务(如ImageNet分类)。
五、关键设计亮点
- 多模态适配:通过
audio_exp
区分音频/图像,灵活处理1/3通道数据。 - 精细掩码:2D掩码区分时间-频率维度,更贴合音频数据的物理结构。
- 灵活解码器:支持ViT/Swin解码器,Swin的窗口注意力更适合二维数据的局部相关性建模。
- 工程优化:重叠Patch、可训练位置嵌入、像素归一化等设计,提升训练稳定性和特征质量。