当前位置: 首页 > news >正文

ViT论文及代码解读

CV


文章目录

  • CV
  • 基本信息
  • 1 摘要
  • 2 模型结构
  • 3 模型详解
    • 3.1 预处理模块
    • 3.1 多层Transformer模块
    • 3.3 分类模块
  • 4 预训练与结果
    • 4.1 ViT更需要预训练
    • 4.2 ViT模型更容易泛化到下游任务
    • 4.3 与SOTA对比
  • 5 总结与思考
  • 参考


基本信息

论文:An Image Is Worth 16*16 Words:Transformers For Image Recognition at scale

时间:2021年

发表于:ICLR

github源码:https://github.com/zlove-summer/Vision-Transformer-pytorch

论文链接:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

1 摘要

transformer于2017年的Attention is all your need提出,一般用于自然语言处理任务(序列任务),因为其相较于循环神经网络RNN能够并行化处理而广泛使用。本文最大的创新点就是将transformer应用于图像分类的cv任务,证明在cv领域使用Transformer依然可以获得很好的性能,启发了后面基于transformer的目标检测和语义分割等网络。

图像是三维矩阵结构,如何将其变为序列是一个问题。在本文中,将图像使用卷积进行分块(14*14=196),再每一块进行展平处理变成序列,然后将这196个序列添加位置编码和cls token,再输入多层Transformer结构中。最后将cls tooken取出来通过一个MLP(多层感知机)用于分类。图一是ViT的整体模型结构。
在这里插入图片描述
图1 ViT整体结构图

2 模型结构

图二[1]是ViT模型的基本结构。可以看出整体模型还是非常简洁的,我将其分为预处理、Transformer模块和分类模块。
在这里插入图片描述
图2 ViT模型的详细结构

和很多论文一样,ViT也按照模型的深度分为了多种版本,如图3所示。其中Layers是Transformer模块的个数,Hidden size是分块时卷积升维后的通道数量,也是每一个块的序列长度。
在这里插入图片描述
图3 按照模型深度和参数量的多版本的ViT

3 模型详解

本节将按照一张图片输入ViT模型后,经过的变换,详解ViT的代码。

3.1 预处理模块

预处理模块的结构如图4所示。处理流程如下:
1、一张2242243的图片,通过一个卷积核大小为1616、步长为16、输出通道为768的卷积,得到1414768的输出。
2、14
14768的输出,将其按照宽高进行Flatten,其shape变成196768,表示为196个序列,每个序列长度为768。
3、在196768的数据上,cat一个1768的分类token在最前面。则shape变成197768。我们设这个197768的矩阵为A。
4、设置一个1197768的Position Embedding,对应值相加至A。分类token和Position Embedding都需要nn.init.trunc_normal_进行初始化。
在这里插入图片描述
图4 ViT的预处理模块

  • PatchEmbed模型(图4的Patch Embedding,用于得到196*768的序列)
class PatchEmbed(nn.Module):"""2D Image to Patch Embedding,二维图像patch Embedding"""def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):super().__init__()img_size = (img_size, img_size)  # 图片尺寸224*224patch_size = (patch_size, patch_size)  #下采样倍数,一个grid cell包含了16*16的图片信息self.img_size = img_sizeself.patch_size = patch_size# grid_size是经过patchembed后的特征层的尺寸self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])self.num_patches = self.grid_size[0] * self.grid_size[1] #path个数 14*14=196# 通过一个卷积,完成patchEmbedself.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size,stride=patch_size)# 如果使用了norm层,如BatchNorm2d,将通道数传入,以进行归一化,否则进行恒等映射self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
​def forward(self, x):B, C, H, W = x.shape  #batch,channels,heigth,weigth# 输入图片的尺寸要满足既定的尺寸assert H == self.img_size[0] and W == self.img_size[1], \f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."# proj: [B, C, H, W] -> [B, C, H,W] , [B,3,224,224]-> [B,768,14,14]# flatten: [B, C, H, W] -> [B, C, HW] , [B,768,14,14]-> [B,768,196]# transpose: [B, C, HW] -> [B, HW, C] , [B,768,196]-> [B,196,768]x = self.proj(x).flatten(2).transpose(1, 2)x = self.norm(x)return x

3.1 多层Transformer模块

将预处理得到的输出,经过多层Transformer模块,用于特征的提取,如图2所示。多层Transformer模块,顾名思义就是多次叠加Transformer模块。因此本节主要讲解Transformer,模块结构如图5[1]。
在这里插入图片描述
图5 Transformer模块结构

Transformer模块主要有两个部分,一个是Muti-head Attention,另一个是MLP。

  • Muti-head Attention

Transformer与CNN最大的不同就是这个自注意力结构,它能够使得网络看到全局的信息,而不是CNN的局部感受野。self-attention的计算方式如图6。

在这里插入图片描述
图6 Self-attention 计算

关于Self-attention的详细计算可以看这篇文章 Vision Transfromer:解读self-attention 。其实就是,序列a,经过三个不同的矩阵,得到 q, k, v,q与k点乘得到相关系数 a,对所有 a 缩放并softmax归一化,再分别乘以v进行加权,得到输出序列b。
在代码实现时,矩阵相乘可以用一个线性层得到(nn.Linear),代码实现时,将线性层的输出通道为原来3倍,可以一次性得到 q, k, v。这里的head变成muti-head操作是直接将维度(如768)除以num-head(如12)得到多个head。

class Attention(nn.Module):"""muti-head attention模块,也是transformer最主要的操作"""def __init__(self,dim,   # 输入token的dim,768num_heads=8, #muti-head的head个数,实例化时base尺寸的vit默认为12qkv_bias=False,qk_scale=None,attn_drop_ratio=0.,proj_drop_ratio=0.):super(Attention, self).__init__()self.num_heads = num_headshead_dim = dim // num_heads  #平均每个head的维度self.scale = qk_scale or head_dim ** -0.5  #进行query操作时,缩放因子# qkv矩阵相乘操作,dim * 3使得一次性进行qkv操作self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop_ratio)self.proj = nn.Linear(dim, dim)  #一个卷积层self.proj_drop = nn.Dropout(proj_drop_ratio)
​def forward(self, x):# [batch_size, num_patches + 1, total_embed_dim] 如 [bactn,197,768]B, N, C = x.shape  # N:197 , C:768# qkv进行注意力操作,reshape进行muti-head的维度分配,permute维度调换以便后续操作# qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim] 如 [b,197,2304]# reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head] 如 [b,197,3,12,64],在这一步中实现了muti-head操作# permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)# qkv的维度相同,[batch_size, num_heads, num_patches + 1, embed_dim_per_head]q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)# transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]# @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]attn = (q @ k.transpose(-2, -1)) * self.scale  #矩阵相乘操作attn = attn.softmax(dim=-1) #每一path进行softmax操作attn = self.attn_drop(attn)# [b,12,197,197]@[b,12,197,64] -> [b,12,197,64]# @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]# 维度交换 transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]# reshape: -> [batch_size, num_patches + 1, total_embed_dim]x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)  #经过一层卷积x = self.proj_drop(x)  #Dropoutreturn x
  • MLP

MLP就是一个两层感知机,如图5右侧所示。隐藏层通道数升维为原来4倍。

class Mlp(nn.Module):"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""def __init__(self, in_features, hidden_features=None, out_features=None,act_layer=nn.GELU,  # GELU是更加平滑的reludrop=0.):super().__init__()out_features = out_features or in_features  #如果out_features不存在,则为in_featureshidden_features = hidden_features or in_features #如果hidden_features不存在,则为in_features。论文中hidden_features升维为原来的4倍self.fc1 = nn.Linear(in_features, hidden_features) #fc层1self.act = act_layer() #激活self.fc2 = nn.Linear(hidden_features, out_features)  #fc层2self.drop = nn.Dropout(drop)
​def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return x
  • Transformer基本模块

由Self-attention和MLP可以组合成Transformer的基本模块。Transformer的基本模块还使用了残差连接结构。

class Block(nn.Module):"""基本的Transformer模块"""def __init__(self,dim,num_heads, mlp_ratio=4.,qkv_bias=False, qk_scale=None, drop_ratio=0.,attn_drop_ratio=0., drop_path_ratio=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm):super(Block, self).__init__()self.norm1 = norm_layer(dim)  #norm层self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here# 代码使用了DropPath,而不是原版的dropoutself.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()self.norm2 = norm_layer(dim) #norm层mlp_hidden_dim = int(dim * mlp_ratio)  #隐藏层维度扩张后的通道数# 多层感知机self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
​def forward(self, x):x = x + self.drop_path(self.attn(self.norm1(x)))  # attention后残差连接x = x + self.drop_path(selfpython'.mlp(self.norm2(x)))   # mlp后残差连接return x

3.3 分类模块

分类头很简单,就是取特征层如197768的第一个向量,即1768,再对此进行线性全连接层进行多分类即可。

# self.num_features=768
# num_classes为分类任务的类别数量
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

4 预训练与结果

4.1 ViT更需要预训练

ViT的模型整体参数量是较大的,一个ViT-base的预训练权重就高达400M,相较于MobileNet-v2的13M和ResNet34的85M,超出较多。所以,ViT模型相较于CNN网络更加需要大数据集的预训练。文中做了一个实验,使用不同规模的ImageNet和JFT数据集,进行预训练,比较其与CNN模型的性能。如图7所示。
在这里插入图片描述
图7 不同数据量的预训练对结果的影响对比

在数据量较小时,无论是在ImageNet还是JFT数据集,BiT(以ResNet为骨干的CNN模型)准确率相对更高,但是当数据集量增大到一定程度时,ViT模型略优于CNN模型。所以,ViT模型更需要大数据集进行预训练,以提高模型的表征。

4.2 ViT模型更容易泛化到下游任务

我们知道,对于CNN网络,即使有预训练权重,当使用这个网络泛化到其他下游任务时,也需要训练较长时间才能达到较好的结果。但是,对于ViT模型来说,当拥有ViT的预训练权重时,只需要训练几个epoch既可以拥有很好的性能。

我曾做过实验,无论是使用小模型和轻量化模型AlexNet、MobileNetv2,还是使用大模型ResNet50,要达到较好预测,都要训练30-50epoch甚至更高。而使用ViT模型仅需要2-3个epoch便可达到更优秀的性能。这部分实验的文章稍后会写。

在文章关于此部分的实验结果如图8所示,可以看出训练7个epoch时,ViT类的模型相较于CNN模型,效果更好。
在这里插入图片描述
图8 模型结果的详细信息

4.3 与SOTA对比

图9是ViT模型与SOTA模型在多个任务的对比,可以看出,在各个任务下,ViT模型都表现更好,当然因为SOTA模型也很优秀所以没有高太多。
在这里插入图片描述
图9 ViT与SOTA对比

5 总结与思考

2021年,一个Transformer引爆的CV圈,从这篇ViT论文中,还是可以看出很多与CNN不足之处,比如需要更大规模的预训练,模型的参数量过大,相较于CNN来说效果提升不够明显等等。但是这篇文章我觉得更大的贡献是将人们的视野扩宽,以前人们做cv任务一般就是在CNN网络上进行各种结构修改,现在ViT给了大家全新的视野。

此外,由于Transformer也能处理cv任务了,那么对于多模态、视觉语言融合等方向就有了更多成功的可能。经过又一年学术界的探索,现在基于Transformer的目标检测和语义分割等等CV主流任务,在效果上已经较多地超过基于CNN的模型了。关于Transformer参数量过大的问题,也有研究者在研究轻量级和稀疏的Transformer模型,如MobileViT、SepViT。

相信随着学术界的发展,工业界也会更多地将ViT模型落地。

参考

[MobileViT] https://mp.weixin.qq.com/s/ckT9XhC8e2ugkolckpY58g
[SepViT] https://zhuanlan.zhihu.com/p/508702384

http://www.xdnf.cn/news/132679.html

相关文章:

  • synchronization
  • 八大排序——冒泡排序/归并排序
  • C++经典知识网页保存
  • 前端开发实用技巧:封装通用下载导出文件或图片方法
  • 2025年深度学习模型发展全景透视(基于前沿技术突破与开源生态演进的交叉分析)
  • 39个常用的AI指令,笔尖Ai写作、DeepSeek、腾讯元宝、豆包、Kimi等都能用
  • 制作一个简单的操作系统10
  • Android开发,实现底部弹出菜单
  • GStreamer 简明教程(十一):插件开发,以一个音频生成(Audio Source)插件为例
  • ‌Linux trap 命令详解
  • report builder问题
  • springboot3 声明式 HTTP 接口
  • JUC多线程:读写锁
  • 【高频考点精讲】前端构建工具对比:Webpack、Vite、Rollup和Parcel
  • 淘宝 /天猫/1688|京东API 常用接口列表与申请方式解析
  • P12167 [蓝桥杯 2025 省 C/Python A] 倒水
  • 对接金蝶获取接口授权代码
  • 第3讲、大模型如何理解和表示单词:词嵌入向量原理详解
  • Blender好用的插件推荐汇总
  • 电脑温度怎么看 查看CPU温度的方法
  • Golang | 位运算
  • DELPHI实现dbTreeView的节点拖动并更新
  • 为什么说美颜SDK动态贴纸才是直播、短视频平台的下一个爆点?看完你就懂了!
  • 连续帧点云目标检测结果展示,python实现
  • 这个免费的AI插件,居然让我5分钟看完2小时的YouTube视频!
  • 大麦项目pro版本来袭!扫平面试中的一切疑难杂症!
  • 视频丨Google 最新 AI 眼镜原型曝光:轻量 XR+情境感知 AI 打造下一代计算平台
  • 【C语言练习】002. 理解C语言的基本语法结构
  • 存储新势力:助力DeepSeek一体机
  • GIT下载步骤