第二十四章:深入CLIP的“心脏”:Vision Transformer (ViT)架构全解析
vit架构
- 前言:当Transformer“厌倦”了文字,决定“开眼看世界”
- 第一章:颠覆CNN的“第一刀”—— 把图片变成“一句话”
- 1.1 革命性思想:放弃卷积,像处理单词一样处理“图像块”
- 1.2 核心操作:Patching -> Flattening -> Linear Projection (切块->展平->投影)
- 1.3 特殊的“句子开头”:[CLS] Token的引入
- 第二章:ViT的完整架构 —— 一个“朴实无华”的标准Transformer
- 2.1 加上“GPS”:位置编码 (Positional Embedding)
- 2.2 主体:一堆“原汁原味”的Transformer Encoder Block
- 2.3 “总结陈词”:只取[CLS] Token的输出用于分类
- 第三章:【硬核挑战】用PyTorch从零构建一个迷你ViT
- 3.1 实现PatchEmbedding层,完成“图片分词”
- 3.2 组装完整的ViT模型,整合所有组件
- 第四章:用我们手写的ViT,完成一次图像分类的前向传播
- CNN vs. ViT:一场“归纳偏置”的哲学思辨
- 总结与展望:视觉世界的新“范式”
前言:当Transformer“厌倦”了文字,决定“开眼看世界”
在很长一段时间里,计算机视觉(CV)领域是**卷积神经网络(CNN)**的天下。CNN通过其精巧的“卷积核”和“池化”操作,模拟了生物视觉系统的“局部感受野”和“层级抽象”特性,在图像识别等任务上取得了巨大成功。
与此同时,在自然语言处理(NLP)领域,Transformer正凭借其强大的自注意力机制,掀起一场革命,以前所未有的方式捕捉着语言中的长距离依赖关系。
一个大胆的想法在Google的研究者脑中萌生:我们能否完全抛弃CNN,用一个“纯粹”的Transformer模型,来直接处理图像?
这个想法的结晶,就是Vision Transformer (ViT)。它粗暴但极其有效地将图片“翻译”成了Transformer能“阅读”的语言,一举在多个图像识别基准上超越了顶级的CNN模型,开启了CV领域的新范式。今天,我们就来解剖这只“跨界巨兽”。
第一章:颠覆CNN的“第一刀”—— 把图片变成“一句话”
深入ViT最核心、最反直觉的创新点——如何将一张二维的图片,预处理成一个一维的“单词”序列。
1.1 革命性思想:放弃卷积,像处理单词一样处理“图像块”
ViT的作者们做了一个惊人的决定:我们不再逐个像素地、局部地去看图片。我们直接把图片像切蛋糕一样,切成一堆小方块!
每一个“小方块”,就被当作是这个“图片句子”里的一个**“单词”(视觉Token)
1.2 核心操作:Patching -> Flattening -> Linear Projection (切块->展平->投影)
这个“图片分词”的过程,分为三步:
Patching (切块):将一张输入的224x224的图片,切成14x14=196个不重叠的16x16的小图像块(Patch)。
Flattening (展平):将每一个16x16x3(3是RGB通道)的小块,都“拉”成一个长长的一维向量,长度为16163 = 768。
Linear Projection (线性投影):将这个768维的向量,通过一个可学习的线性层(Embedding层),投影成模型需要的、蕴含更丰富信息的D维向量(比如还是768维)。
经过这三步,一张图片,就从一个二维的像素网格,变成了一个由196个“视觉Token”组成的序列!它的形状是[196, 768],这和我们处理一个有196个单词的句子时的数据形状,完全一样了!
1.3 特殊的“句子开头”:[CLS] Token的引入
借鉴BERT模型的成功经验,ViT在这一串196个“视觉Token”的最前面,还人为地加入了一个额外的、可学习的特殊Token,叫做**[CLS] Token**(Classification Token)。
它的作用:就像一个“班长”,负责“总览全局”。在经过多层Transformer的计算后,所有图像块的信息会通过自注意力机制,
不断地汇聚到这个[CLS] Token上。最终,我们只需要看这个“班长”的状态,就能知道整张图片表达了什么内容。
第二章:ViT的完整架构 —— 一个“朴实无华”的标准Transformer
介绍在完成了“图片分词”后,ViT是如何直接套用一个标准的Transformer Encoder架构来完成后续处理的。
2.1 加上“GPS”:位置编码 (Positional Embedding)
和处理文本一样,直接把一堆“图像块”向量送入Transformer,会丢失它们原始的空间位置信息。因此,我们也需要为每一个Patch Embedding,加上一个可学习的“位置编码”,告诉模型这个“单词”原本在图片的哪个位置。
2.2 主体:一堆“原汁原味”的Transformer Encoder Block
加上位置编码后,这个[197, 768]的序列(196个Patch + 1个CLS),就被送入了一个由多个标准Transformer Encoder Block堆叠而成的主体结构中。
里面的计算,和我们之前学习的完全一样:层归一化 -> 多头自注意力 -> 残差连接 -> 层归一化 -> 前馈网络 -> 残差连接。
在这个过程中,每一个图像块,都在和所有其他的图像块(包括[CLS] Token)进行着复杂的“信息交流”。
2.3 “总结陈词”:只取[CLS] Token的输出用于分类
当数据流经所有Transformer Block后,我们得到一个最终的[197, 768]的输出序列。
对于图像分类任务,ViT做了一个非常简洁的设计:它直接“扔掉”了后面196个图像块的输出,只取第一个、也就是[CLS] Token对应的那个768维输出向量。
因为经过了多层全局的自注意力计算,这个[CLS] Token的最终状态,已经聚合和浓缩了整张图片的所有高级语义信息。
最后,将这个[CLS]向量,喂给一个简单的全连接分类头(MLP Head),就能得到最终的分类结果(比如“猫”的概率是99%)。
第三章:【硬核挑战】用PyTorch从零构建一个迷你ViT
3.1 实现PatchEmbedding层,完成“图片分词”
我们将实现ViT最核心的创新点——PatchEmbedding层。它负责将一张连续的图片,转换成一个离散的、携带位置信息的“视觉Token”序列,为后续的Transformer处理做好准备。
# vit_building_blocks.pyimport torch
import torch.nn as nnclass PatchEmbedding(nn.Module):"""将图片分割成块(Patch),并进行线性投影,最终加上[CLS] Token和位置编码。Args:img_size (int): 输入图片的大小 (假设H=W).patch_size (int): 每个Patch的大小.in_channels (int): 输入图片的通道数 (如RGB=3).embed_dim (int): 线性投影后的嵌入维度 (D)."""def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):super().__init__()self.img_size = img_sizeself.patch_size = patch_sizeself.num_patches = (img_size // patch_size) ** 2# --- 核心技巧:用一个卷积层同时实现“切块”和“线性投影” ---# kernel_size=patch_size: 每个卷积核的大小等于一个patch的大小。# stride=patch_size: 卷积核每次移动的步长也等于一个patch的大小,确保了patch之间不重叠。# out_channels=embed_dim: 卷积核的数量等于我们想要的嵌入维度。# 这样,每个卷积核作用在一个patch上,就会输出一个embed_dim维的向量。self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)# --- 创建 [CLS] Token ---# 这是一个可学习的参数,形状为 [1, 1, embed_dim],1个批次, 1个token, embed_dim维self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))# --- 创建位置编码 ---# 同样是可学习的参数,我们需要 num_patches + 1 个位置编码(要加上CLS token的位置)self.positional_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))def forward(self, x):# x 形状: [batch_size, in_channels, img_size, img_size] -> e.g., [B, 3, 224, 224]# 1. 卷积投影x = self.projection(x)# x 形状: [B, embed_dim, H', W'] -> e.g., [B, 768, 14, 14]# 2. 展平# .flatten(2) 会将H'和W'维度展平# .transpose(1, 2) 将 embed_dim 和 序列长度维度交换x = x.flatten(2).transpose(1, 2)# x 形状: [B, num_patches, embed_dim] -> e.g., [B, 196, 768]# 3. 添加 [CLS] Token# 将cls_token复制batch_size份cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)# 在序列的开头拼接cls_tokenx = torch.cat((cls_tokens, x), dim=1)# x 形状: [B, num_patches + 1, embed_dim] -> e.g., [B, 197, 768]# 4. 添加位置编码x = x + self.positional_embedding# x 形状: [B, 197, 768] (保持不变)return x
【代码解读】
这段代码最巧妙的地方在于使用nn.Conv2d来实现Patching和Linear Projection。这比手动切块再送入nn.Linear要高效得多。后续的[CLS] Token拼接和位置编码相加,都是标准的Transformer预处理操作。这个模块的输出,就是一个完美的、可以被nn.TransformerEncoder直接接收的序列。
3.2 组装完整的ViT模型,整合所有组件
现在,我们将使用刚刚创建的PatchEmbedding层作为“地基”,并从PyTorch库中请来标准的nn.TransformerEncoder作为“主体结构”,再加上一个简单的nn.Linear分类头,将它们组装成一个完整的ViT模型。
# vit_model.py# (假设PatchEmbedding类定义在同一个文件或已导入)
# from vit_building_blocks import PatchEmbedding
import torch
import torch.nn as nnclass VisionTransformer(nn.Module):"""一个简化的Vision Transformer模型。"""def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768,num_heads=12, num_layers=12, mlp_dim=3072, num_classes=1000):super().__init__()# --- 1. Patch Embedding (我们自己写的模块) ---self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)# --- 2. Transformer Encoder (直接使用PyTorch官方实现) ---# 创建一个Transformer Encoder Layer的配置encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim,nhead=num_heads,dim_feedforward=mlp_dim,batch_first=True # 【重要】确保输入和输出的形状是 [B, L, D])# 使用这个配置,堆叠num_layers个层,形成完整的Encoderself.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)# --- 3. 分类头 (MLP Head) ---# 这是一个简单的全连接层,用于最终的分类self.mlp_head = nn.Linear(embed_dim, num_classes)def forward(self, x):# x 形状: [B, 3, 224, 224]# 1. 获取Patch Embeddingsx = self.patch_embed(x)# x 形状: [B, 197, 768]# 2. 通过Transformer Encoderx = self.transformer_encoder(x)# x 形状: [B, 197, 768] (保持不变)# 3. 提取 [CLS] Token的输出用于分类# x[:, 0] 表示取所有batch的第0个token (即[CLS] token)cls_token_output = x[:, 0]# cls_token_output 形状: [B, 768]# 4. 通过分类头,得到最终的Logitslogits = self.mlp_head(cls_token_output)# logits 形状: [B, num_classes] -> e.g., [B, 1000]return logits
【代码解读】
这段代码完美地展示了模块化编程的威力。我们把复杂的部分(如PatchEmbedding和TransformerEncoderLayer)都封装好,然后在顶层的VisionTransformer模型中,像搭乐高一样,将它们按顺序连接起来。forward方法中的数据流,清晰地再现了我们在理论部分讲解的ViT工作流程。
第四章:用我们手写的ViT,完成一次图像分类的前向传播
我们将实例化我们亲手打造的VisionTransformer模型,并创建一个符合输入要求的“假”图像Tensor,完成一次完整的前向传播,以此来验证我们整个模型的结构和数据流是否正确无误。
【代码实现】
# run_vit.py# (假设VisionTransformer类定义在同一个文件或已导入)
# from vit_model import VisionTransformer
import torchdef main():# --- 1. 定义模型参数 (使用ViT-Base的配置) ---img_size = 224patch_size = 16embed_dim = 768num_heads = 12num_layers = 12mlp_dim = 3072num_classes = 1000 # ImageNet的1000个类别# --- 2. 实例化我们手写的ViT模型 ---print("正在实例化Vision Transformer模型...")model = VisionTransformer(img_size=img_size,patch_size=patch_size,embed_dim=embed_dim,num_heads=num_heads,num_layers=num_layers,mlp_dim=mlp_dim,num_classes=num_classes)model.eval() # 设置为评估模式print("模型实例化完成!")# --- 3. 创建一个符合输入要求的“假”图像Batch ---batch_size = 4# 形状为 [B, C, H, W] 的随机Tensordummy_image_batch = torch.randn(batch_size, 3, img_size, img_size)print(f"\n创建了一个假的输入图像Batch,形状为: {dummy_image_batch.shape}")# --- 4. 执行一次完整的前向传播 ---print("正在执行前向传播...")with torch.no_grad():output_logits = model(dummy_image_batch)print("前向传播完成!")# --- 5. 检查输出的形状 ---print(f"\n最终模型输出 (Logits) 的形状为: {output_logits.shape}")print(f"✅ 验证成功:形状为 [{batch_size}, {num_classes}],完全符合我们的预期!")if __name__ == "__main__":main()
【代码解读与见证奇迹】
运行这段脚本,你不会看到任何炫酷的分类结果,因为我们用的是随机初始化的权重。但是,你会看到一个没有报错、流畅运行的完整流程,以及一个形状完全正确的最终输出torch.Size([4, 1000])。
这看似平淡无奇,实则意义非凡。它证明了我们亲手编写的、包含PatchEmbedding、TransformerEncoder和MLPHead在内的复杂模型,其内部所有模块的维度衔接天衣无缝。我们已经成功地搭建起了ViT这座大厦的“钢筋骨架”!有了这个骨架,后续只需要用海量数据进行训练,填充“血肉”(学习到的权重),它就能真正地“开眼看世界”了。
CNN vs. ViT:一场“归纳偏置”的哲学思辨
CNN (卷积神经网络):拥有强大的**“归纳偏置 (Inductive Bias)”。它的设计(局部连接、权重共享)天生就“假设”了图像信息具有“局部性”和“平移不变性”**。这使得它在数据量较少时,能学得更快、更好。
ViT (视觉Transformer):几乎没有归纳偏置。它对图像的结构一无所知,把所有Patch一视同仁。它完全依赖于自注意力机制,从海量数据中,自己去学习图像的空间关系。
结论:在数据量“小”的时候,CNN的“先天经验”更占优势。但在数据量“巨大”(亿级、十亿级)的时候,ViT这种“一张白纸”的模型,反而能不受“先天经验”的束缚,从数据中学习到更普适、更强大的视觉模式,从而达到更高的性能上限。
总结与展望:视觉世界的新“范式”
恭喜你!今天你已经彻底解构了计算机视觉领域的一座重要里程碑。
✨ 本章惊喜概括 ✨
你掌握了什么? | 对应的技能/工具 |
---|---|
理解了ViT的核心思想 | ✅ 将图片“分词”为Patch序列 |
洞悉了其数据处理流程 | ✅ Patching -> Flattening -> Projection |
掌握了[CLS] Token的妙用 | ✅ 作为全局信息的“聚合器” |
亲手构建了模型 | ✅ 从零实现了迷你ViT的PyTorch代码 |
理解了其哲学 | ✅ 归纳偏置与数据量的权衡 |
ViT的成功,不仅在于它本身强大的性能,更在于它统一了NLP和CV两大领域的模型范式。它证明了Transformer这个强大的架构,是通往更通用人工智能(AGI)的一条极具潜力的道路。 |