YoloV12 Improvement Strategy: Block Improvement | TAB, Fusing Intra-Group Self-Attention (IASA) and Inter-Group Cross-Attention (IRCA) | Plug-and-Play
Paper Information
The paper proposes a novel lightweight image super-resolution network called the Content-Aware Token Aggregation Network (CATANet). It targets the computational complexity that Transformer-based methods suffer from at high spatial resolutions. CATANet captures long-range dependencies through an efficient Content-Aware Token Aggregation (CATA) module while maintaining high inference speed.
- Paper link: https://arxiv.org/pdf/2503.06896
- GitHub code link: https://github.com/EquationWalker/CATANet
Innovations
- Content-Aware Token Aggregation module (CATA): this module aggregates content-similar tokens to reduce redundant information and improve efficiency. Unlike conventional methods, CATA updates the token centers during training but keeps them fixed at inference, which speeds up inference (a sketch of this train/inference asymmetry follows this list).
- Intra-Group Self-Attention (IASA) and Inter-Group Cross-Attention (IRCA): these two mechanisms handle information exchange within content-similar regions and across groups, respectively, strengthening the model's ability to capture long-range dependencies.
- Lightweight design: CATANet is suited to resource-constrained environments such as mobile devices, and it significantly improves processing speed while maintaining high performance.
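To make the train/inference asymmetry concrete, here is a minimal sketch, assuming hypothetical names (TokenCenters, decay) and a simple per-center mean update; CATANet's actual update (center_iter in the code below) uses a normalized scatter-add instead, so this is an illustration, not the paper's implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TokenCenters(nn.Module):
    """Hypothetical sketch: EMA-updated token centers, frozen at inference."""
    def __init__(self, num_tokens=8, dim=64, decay=0.999):
        super().__init__()
        self.decay = decay
        # A buffer, not a parameter: updated by EMA, never by gradients.
        self.register_buffer('means', torch.randn(num_tokens, dim))

    def forward(self, x):  # x: (B, N, C) token features
        if self.training:
            with torch.no_grad():
                # Assign every token to its nearest center by cosine similarity.
                sim = F.normalize(x, dim=-1) @ F.normalize(self.means, dim=-1).t()
                assign = sim.argmax(dim=-1)                  # (B, N)
                new_means = self.means.clone()
                for c in range(self.means.shape[0]):
                    mask = assign == c
                    if mask.any():
                        new_means[c] = x[mask].mean(dim=0)   # batch mean per center
                # EMA: drift slowly toward the batch estimate.
                self.means.mul_(self.decay).add_(new_means, alpha=1 - self.decay)
        # At inference the stored centers are reused as-is: no extra iterations.
        return self.means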
Method
The overall CATANet architecture consists of three main modules:
- Shallow feature extraction: a convolutional layer maps the low-resolution input image to high-dimensional features.
- Deep feature extraction: multiple residual groups (RG), each containing a CATA module, IASA, and convolutional layers, extract deeper features.
- Image reconstruction module: processes the extracted features to produce the high-resolution output image.
In the CATA module, tokens are clustered by their similarity to the token centers, forming groups of content-similar tokens. IASA and IRCA then perform attention within and across these groups, respectively, to capture richer contextual information. A simplified sketch of the grouping step follows.
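As a rough illustration of this grouping step (a simplified sketch; the function name, shapes, and the plain argsort-based grouping are assumptions, not the exact CATA routine):

import torch
import torch.nn.functional as F

def group_tokens_by_center(x, centers):
    """Assign each token to its most similar center and sort tokens by group.

    x:       (B, N, C) token features
    centers: (M, C)    shared token centers
    Returns the sorted tokens and the permutation used, so the original
    order can be restored after attention with a gather/scatter.
    """
    sim = torch.einsum('bnc,mc->bnm',
                       F.normalize(x, dim=-1),
                       F.normalize(centers, dim=-1))   # cosine similarity
    group_id = sim.argmax(dim=-1)                      # (B, N) nearest center
    perm = group_id.argsort(dim=-1)                    # same-group tokens become contiguous
    x_sorted = torch.gather(x, 1, perm.unsqueeze(-1).expand_as(x))
    return x_sorted, perm

# Example: 4 tokens per image, 2 centers
x = torch.randn(2, 4, 8)
centers = torch.randn(2, 8)
x_sorted, perm = group_tokens_by_center(x, centers)
print(x_sorted.shape, perm.shape)  # torch.Size([2, 4, 8]) torch.Size([2, 4])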
How Intra-Group Self-Attention (IASA) and Inter-Group Cross-Attention (IRCA) improve model performance
Intra-Group Self-Attention (IASA)
- Fine-grained long-range interaction: IASA operates on groups of content-similar tokens, allowing fine-grained long-range information exchange between them. Even within local regions, the model can thus draw on wider context, improving the expressiveness of the features.
- Avoiding information fragmentation: by letting the queries of each sub-group attend to the keys of both the current and the previous group, IASA avoids scattering the information of content-similar tokens across sub-group boundaries. This keeps the information coherent and complete, helping the model understand and reconstruct image detail (see the sketch after this list).
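The "current group plus previous group" key window can be built with a strided unfold. The sketch below assumes a single head and N divisible by the group size, and simply duplicates the first group as padding, whereas the paper's code mirrors the sequence tail instead:

import torch
import torch.nn.functional as F

B, N, C, gs = 1, 8, 6, 2          # batch, tokens, channels, group size
ng = N // gs                       # number of groups (assume N divisible by gs)
q = torch.randn(B, N, C)
k = torch.randn(B, N, C)
v = torch.randn(B, N, C)

# Queries: one window per group.
q_g = q.reshape(B, ng, gs, C)

# Keys/values: prepend one extra group so each window covers the previous
# group and the current one (window length 2*gs, stride gs).
k_pad = torch.cat((k[:, :gs], k), dim=1)             # (B, N + gs, C)
v_pad = torch.cat((v[:, :gs], v), dim=1)
k_g = k_pad.unfold(1, 2 * gs, gs).transpose(-1, -2)  # (B, ng, 2*gs, C)
v_g = v_pad.unfold(1, 2 * gs, gs).transpose(-1, -2)

out = F.scaled_dot_product_attention(q_g, k_g, v_g)  # (B, ng, gs, C)
print(out.shape)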
Inter-Group Cross-Attention (IRCA)
- Stronger global interaction: IRCA lets different groups, and groups and the token centers, exchange information. The model can therefore exploit global information and better understand the overall structure and features of the image.
- Efficiency: at inference time the token centers are shared, so IRCA needs no additional heavy computation, keeping the model efficient. In this way IRCA improves performance without adding computational burden (a minimal sketch follows).
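A minimal single-head sketch of cross-attention against a small shared set of token centers (illustrative class and names, not the real IRCA, which also splits heads and shares projections with IASA); the point is that the cost scales with N·M rather than N²:

import torch
import torch.nn as nn

class CenterCrossAttention(nn.Module):
    """Hypothetical sketch: N tokens attend to M shared centers,
    so attention costs O(N * M) instead of O(N^2)."""
    def __init__(self, dim, qk_dim):
        super().__init__()
        self.to_q = nn.Linear(dim, qk_dim, bias=False)
        self.to_k = nn.Linear(dim, qk_dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)

    def forward(self, x, centers):                # x: (B, N, C), centers: (M, C)
        q = self.to_q(x)                          # (B, N, qk)
        k = self.to_k(centers)                    # (M, qk) -- shared, computed once
        v = self.to_v(centers)                    # (M, C)
        attn = (q @ k.t() / k.shape[-1] ** 0.5).softmax(dim=-1)  # (B, N, M)
        return attn @ v                           # (B, N, C)

x = torch.randn(2, 16, 32)
centers = torch.randn(8, 32)
out = CenterCrossAttention(32, 16)(x, centers)
print(out.shape)  # torch.Size([2, 16, 32])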
Token-Aggregation Block (TAB)
The Token-Aggregation Block (TAB) is the core component and consists of the following parts:
- Content-Aware Token Aggregation module (CATA): CATA aggregates content-similar tokens. It computes the cosine similarity between image tokens and a shared set of token centers and partitions the tokens into content-similar groups. This aggregation reduces redundant information and improves efficiency.
- Intra-Group Self-Attention (IASA): IASA performs fine-grained long-range interaction within each group of content-similar tokens. By letting the queries of each sub-group attend to the keys of the current and the previous group, it strengthens information capture and keeps the information coherent and complete.
- Inter-Group Cross-Attention (IRCA): IRCA lets different groups, and groups and the token centers, exchange information, enabling the model to exploit global information and better understand the overall structure and features of the image.
- Convolutional layer: at the end of the TAB, a 1×1 convolution further refines local features and implicitly learns a positional embedding.
Results
Experiments show that CATANet performs strongly on several public super-resolution datasets (Set5, Set14, B100, etc.). Compared with SPIN, the state-of-the-art clustering-based method, CATANet improves PSNR by up to 0.33 dB while nearly doubling inference speed. These results demonstrate the effectiveness and efficiency of CATANet for lightweight image super-resolution.
Code
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from inspect import isfunction


def exists(val):
    return val is not None


def is_empty(t):
    return t.nelement() == 0


def expand_dim(t, dim, k):
    t = t.unsqueeze(dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)


def default(x, d):
    if not exists(x):
        return d if not isfunction(d) else d()
    return x


def ema(old, new, decay):
    if not exists(old):
        return new
    return old * decay + new * (1 - decay)


def ema_inplace(moving_avg, new, decay):
    if is_empty(moving_avg):
        moving_avg.data.copy_(new)
        return
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def similarity(x, means):
    return torch.einsum('bld,cd->blc', x, means)


def dists_and_buckets(x, means):
    dists = similarity(x, means)
    _, buckets = torch.max(dists, dim=-1)
    return dists, buckets


def batched_bincount(index, num_classes, dim=-1):
    shape = list(index.shape)
    shape[dim] = num_classes
    out = index.new_zeros(shape)
    out.scatter_add_(dim, index, torch.ones_like(index, dtype=index.dtype))
    return out


def center_iter(x, means, buckets=None):
    # One refinement step: reassign tokens to centers, then recompute each
    # center from the tokens assigned to it (empty centers are kept as-is).
    b, l, d, dtype, num_tokens = *x.shape, x.dtype, means.shape[0]
    if not exists(buckets):
        _, buckets = dists_and_buckets(x, means)
    bins = batched_bincount(buckets, num_tokens).sum(0, keepdim=True)
    zero_mask = bins.long() == 0
    means_ = buckets.new_zeros(b, num_tokens, d, dtype=dtype)
    means_.scatter_add_(-2, expand_dim(buckets, -1, d), x)
    means_ = F.normalize(means_.sum(0, keepdim=True), dim=-1).type(dtype)
    means = torch.where(zero_mask.unsqueeze(-1), means, means_)
    means = means.squeeze(0)
    return means


class IASA(nn.Module):
    def __init__(self, dim, qk_dim, heads, group_size):
        super().__init__()
        self.heads = heads
        self.to_q = nn.Linear(dim, qk_dim, bias=False)
        self.to_k = nn.Linear(dim, qk_dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.group_size = group_size

    def forward(self, normed_x, idx_last, k_global, v_global):
        x = normed_x
        B, N, _ = x.shape
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        q = torch.gather(q, dim=-2, index=idx_last.expand(q.shape))
        k = torch.gather(k, dim=-2, index=idx_last.expand(k.shape))
        v = torch.gather(v, dim=-2, index=idx_last.expand(v.shape))
        gs = min(N, self.group_size)  # group size
        ng = (N + gs - 1) // gs
        pad_n = ng * gs - N
        paded_q = torch.cat((q, torch.flip(q[:, N - pad_n:N, :], dims=[-2])), dim=-2)
        paded_q = rearrange(paded_q, "b (ng gs) (h d) -> b ng h gs d", ng=ng, h=self.heads)
        paded_k = torch.cat((k, torch.flip(k[:, N - pad_n - gs:N, :], dims=[-2])), dim=-2)
        paded_k = paded_k.unfold(-2, 2 * gs, gs)
        paded_k = rearrange(paded_k, "b ng (h d) gs -> b ng h gs d", h=self.heads)
        paded_v = torch.cat((v, torch.flip(v[:, N - pad_n - gs:N, :], dims=[-2])), dim=-2)
        paded_v = paded_v.unfold(-2, 2 * gs, gs)
        paded_v = rearrange(paded_v, "b ng (h d) gs -> b ng h gs d", h=self.heads)
        out1 = F.scaled_dot_product_attention(paded_q, paded_k, paded_v)
        k_global = k_global.reshape(1, 1, *k_global.shape).expand(B, ng, -1, -1, -1)
        v_global = v_global.reshape(1, 1, *v_global.shape).expand(B, ng, -1, -1, -1)
        out2 = F.scaled_dot_product_attention(paded_q, k_global, v_global)
        out = out1 + out2
        out = rearrange(out, "b ng h gs d -> b (ng gs) (h d)")[:, :N, :]
        out = out.scatter(dim=-2, index=idx_last.expand(out.shape), src=out)
        out = self.proj(out)
        return out


class IRCA(nn.Module):
    def __init__(self, dim, qk_dim, heads):
        super().__init__()
        self.heads = heads
        self.to_k = nn.Linear(dim, qk_dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)

    def forward(self, normed_x, x_means):
        x = normed_x
        if self.training:
            x_global = center_iter(F.normalize(x, dim=-1), F.normalize(x_means, dim=-1))
        else:
            x_global = x_means
        k, v = self.to_k(x_global), self.to_v(x_global)
        k = rearrange(k, 'n (h dim_head)->h n dim_head', h=self.heads)
        v = rearrange(v, 'n (h dim_head)->h n dim_head', h=self.heads)
        return k, v, x_global.detach()


class TAB(nn.Module):
    def __init__(self, dim, qk_dim, mlp_dim, heads, n_iter=3,
                 num_tokens=8, group_size=128, ema_decay=0.999):
        super().__init__()
        self.n_iter = n_iter
        self.ema_decay = ema_decay
        self.num_tokens = num_tokens
        self.norm = nn.LayerNorm(dim)
        self.mlp = PreNorm(dim, ConvFFN(dim, mlp_dim))
        self.irca_attn = IRCA(dim, qk_dim, heads)
        self.iasa_attn = IASA(dim, qk_dim, heads, group_size)
        self.register_buffer('means', torch.randn(num_tokens, dim))
        self.register_buffer('initted', torch.tensor(False))
        self.conv1x1 = nn.Conv2d(dim, dim, 1, bias=False)

    def forward(self, x):
        _, _, h, w = x.shape
        x = rearrange(x, 'b c h w->b (h w) c')
        residual = x
        x = self.norm(x)
        B, N, _ = x.shape
        idx_last = torch.arange(N, device=x.device).reshape(1, N).expand(B, -1)
        if not self.initted:
            pad_n = self.num_tokens - N % self.num_tokens
            paded_x = torch.cat((x, torch.flip(x[:, N - pad_n:N, :], dims=[-2])), dim=-2)
            x_means = torch.mean(rearrange(paded_x, 'b (cnt n) c->cnt (b n) c', cnt=self.num_tokens), dim=-2).detach()
        else:
            x_means = self.means.detach()
        if self.training:
            with torch.no_grad():
                for _ in range(self.n_iter - 1):
                    x_means = center_iter(F.normalize(x, dim=-1), F.normalize(x_means, dim=-1))
        k_global, v_global, x_means = self.irca_attn(x, x_means)
        with torch.no_grad():
            x_scores = torch.einsum('b i c,j c->b i j',
                                    F.normalize(x, dim=-1),
                                    F.normalize(x_means, dim=-1))
            x_belong_idx = torch.argmax(x_scores, dim=-1)
            idx = torch.argsort(x_belong_idx, dim=-1)
            idx_last = torch.gather(idx_last, dim=-1, index=idx).unsqueeze(-1)
        y = self.iasa_attn(x, idx_last, k_global, v_global)
        y = rearrange(y, 'b (h w) c->b c h w', h=h).contiguous()
        y = self.conv1x1(y)
        x = residual + rearrange(y, 'b c h w->b (h w) c')
        x = self.mlp(x, x_size=(h, w)) + x
        if self.training:
            with torch.no_grad():
                new_means = x_means
                if not self.initted:
                    self.means.data.copy_(new_means)
                    self.initted.data.copy_(torch.tensor(True))
                else:
                    ema_inplace(self.means, new_means, self.ema_decay)
        return rearrange(x, 'b (h w) c->b c h w', h=h)


def patch_divide(x, step, ps):
    """Crop image into patches.
    Args:
        x (Tensor): Input feature map of shape(b, c, h, w).
        step (int): Divide step.
        ps (int): Patch size.
    Returns:
        crop_x (Tensor): Cropped patches.
        nh (int): Number of patches along the vertical direction.
        nw (int): Number of patches along the horizontal direction.
    """
    b, c, h, w = x.size()
    if h == ps and w == ps:
        step = ps
    crop_x = []
    nh = 0
    for i in range(0, h + step - ps, step):
        top = i
        down = i + ps
        if down > h:
            top = h - ps
            down = h
        nh += 1
        for j in range(0, w + step - ps, step):
            left = j
            right = j + ps
            if right > w:
                left = w - ps
                right = w
            crop_x.append(x[:, :, top:down, left:right])
    nw = len(crop_x) // nh
    crop_x = torch.stack(crop_x, dim=0)  # (n, b, c, ps, ps)
    crop_x = crop_x.permute(1, 0, 2, 3, 4).contiguous()  # (b, n, c, ps, ps)
    return crop_x, nh, nw


def patch_reverse(crop_x, x, step, ps):
    """Reverse patches into image.
    Args:
        crop_x (Tensor): Cropped patches.
        x (Tensor): Feature map of shape(b, c, h, w).
        step (int): Divide step.
        ps (int): Patch size.
    Returns:
        output (Tensor): Reversed image.
    """
    b, c, h, w = x.size()
    output = torch.zeros_like(x)
    index = 0
    for i in range(0, h + step - ps, step):
        top = i
        down = i + ps
        if down > h:
            top = h - ps
            down = h
        for j in range(0, w + step - ps, step):
            left = j
            right = j + ps
            if right > w:
                left = w - ps
                right = w
            output[:, :, top:down, left:right] += crop_x[:, index]
            index += 1
    # Average the overlapping stripes produced by the sliding window.
    for i in range(step, h + step - ps, step):
        top = i
        down = i + ps - step
        if top + ps > h:
            top = h - ps
        output[:, :, top:down, :] /= 2
    for j in range(step, w + step - ps, step):
        left = j
        right = j + ps - step
        if left + ps > w:
            left = w - ps
        output[:, :, :, left:right] /= 2
    return output


class PreNorm(nn.Module):
    """Normalization layer.
    Args:
        dim (int): Base channels.
        fn (Module): Module after normalization.
    """
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class dwconv(nn.Module):
    def __init__(self, hidden_features, kernel_size=5):
        super(dwconv, self).__init__()
        self.depthwise_conv = nn.Sequential(
            nn.Conv2d(hidden_features, hidden_features, kernel_size=kernel_size, stride=1,
                      padding=(kernel_size - 1) // 2, dilation=1,
                      groups=hidden_features),
            nn.GELU())
        self.hidden_features = hidden_features

    def forward(self, x, x_size):
        x = x.transpose(1, 2).view(x.shape[0], self.hidden_features, x_size[0], x_size[1]).contiguous()  # b Ph*Pw c
        x = self.depthwise_conv(x)
        x = x.flatten(2).transpose(1, 2).contiguous()
        return x


class ConvFFN(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, kernel_size=5, act_layer=nn.GELU):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.dwconv = dwconv(hidden_features=hidden_features, kernel_size=kernel_size)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x, x_size):
        x = self.fc1(x)
        x = self.act(x)
        x = x + self.dwconv(x, x_size)
        x = self.fc2(x)
        return x


if __name__ == "__main__":
    # Define the input tensor size (Batch, Channel, Height, Width)
    B, C, H, W = 16, 64, 40, 40
    input_tensor = torch.randn(B, C, H, W)  # random input tensor
    dim = C
    # Create a TAB instance
    block = TAB(dim, 16, 4, 8)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    block = block.to(device)
    print(block)
    input_tensor = input_tensor.to(device)
    # Run a forward pass
    output = block(input_tensor)
    # Print the input and output shapes
    print(f"Input: {input_tensor.shape}")
    print(f"Output: {output.shape}")
TAB class: code walkthrough
The TAB class is a Transformer-based attention module; TAB stands for Token-Aggregation Block. Its implementation is analyzed in detail below.
Class initialization

class TAB(nn.Module):
    def __init__(self, dim, qk_dim, mlp_dim, heads, n_iter=3,
                 num_tokens=8, group_size=128, ema_decay=0.999):
        super().__init__()
Parameter description:
- dim: dimension of the input features
- qk_dim: dimension of the query and key projections
- mlp_dim: hidden dimension of the MLP layer
- heads: number of attention heads
- n_iter: number of iterations (for cluster-center refinement)
- num_tokens: number of cluster centers / token centers
- group_size: group size used in grouped attention
- ema_decay: decay coefficient of the exponential moving average
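For reference, a plausible instantiation using the classes above (the hyperparameter values here are illustrative, not the paper's configuration):

# dim=64 features, 16-dim Q/K, 128-dim ConvFFN hidden layer, 4 heads;
# 8 token centers, groups of 128 tokens, EMA decay 0.999 (the defaults).
block = TAB(dim=64, qk_dim=16, mlp_dim=128, heads=4)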
Main components
- Stored hyperparameters:

  self.n_iter = n_iter
  self.ema_decay = ema_decay
  self.num_tokens = num_tokens
- Normalization layer:

  self.norm = nn.LayerNorm(dim)
- MLP layer:

  self.mlp = PreNorm(dim, ConvFFN(dim, mlp_dim))
- Attention mechanisms:
  - IRCA (Inter-Group Cross-Attention): cross-attention between tokens and the shared token centers
  - IASA (Intra-Group Self-Attention): fine-grained self-attention within content-similar token groups

  self.irca_attn = IRCA(dim, qk_dim, heads)
  self.iasa_attn = IASA(dim, qk_dim, heads, group_size)
- Cluster-center buffers (registered as buffers, so they are updated by EMA rather than by gradients):

  self.register_buffer('means', torch.randn(num_tokens, dim))
  self.register_buffer('initted', torch.tensor(False))
- 1×1 convolution:

  self.conv1x1 = nn.Conv2d(dim, dim, 1, bias=False)
Forward pass

def forward(self, x):
    # Get the input shape
    _, _, h, w = x.shape
    # Flatten the image into a token sequence
    x = rearrange(x, 'b c h w->b (h w) c')
    residual = x  # keep for the residual connection
    # Normalize
    x = self.norm(x)
    B, N, _ = x.shape
    # Initialize the token indices
    idx_last = torch.arange(N, device=x.device).reshape(1, N).expand(B, -1)
    # Initialize the cluster centers
    if not self.initted:
        pad_n = self.num_tokens - N % self.num_tokens
        paded_x = torch.cat((x, torch.flip(x[:, N - pad_n:N, :], dims=[-2])), dim=-2)
        x_means = torch.mean(rearrange(paded_x, 'b (cnt n) c->cnt (b n) c', cnt=self.num_tokens), dim=-2).detach()
    else:
        x_means = self.means.detach()
    # During training, iteratively refine the cluster centers
    if self.training:
        with torch.no_grad():
            for _ in range(self.n_iter - 1):
                x_means = center_iter(F.normalize(x, dim=-1), F.normalize(x_means, dim=-1))
    # Compute the global keys and values
    k_global, v_global, x_means = self.irca_attn(x, x_means)
    # Determine which center each token belongs to
    with torch.no_grad():
        x_scores = torch.einsum('b i c,j c->b i j',
                                F.normalize(x, dim=-1),
                                F.normalize(x_means, dim=-1))
        x_belong_idx = torch.argmax(x_scores, dim=-1)
        # Sort the indices so tokens of the same group are contiguous
        idx = torch.argsort(x_belong_idx, dim=-1)
        idx_last = torch.gather(idx_last, dim=-1, index=idx).unsqueeze(-1)
    # Apply IASA attention
    y = self.iasa_attn(x, idx_last, k_global, v_global)
    # Restore the image shape
    y = rearrange(y, 'b (h w) c->b c h w', h=h).contiguous()
    # Apply the 1x1 convolution
    y = self.conv1x1(y)
    # Residual connection
    x = residual + rearrange(y, 'b c h w->b (h w) c')
    # Apply the MLP
    x = self.mlp(x, x_size=(h, w)) + x
    # During training, update the stored cluster centers
    if self.training:
        with torch.no_grad():
            new_means = x_means
            if not self.initted:
                self.means.data.copy_(new_means)
                self.initted.data.copy_(torch.tensor(True))
            else:
                ema_inplace(self.means, new_means, self.ema_decay)
    return rearrange(x, 'b (h w) c->b c h w', h=h)
Key components in detail
IRCA (Inter-Group Cross-Attention)
class IRCA(nn.Module):
    def __init__(self, dim, qk_dim, heads):
        super().__init__()
        self.heads = heads
        self.to_k = nn.Linear(dim, qk_dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)

    def forward(self, normed_x, x_means):
        x = normed_x
        if self.training:
            # During training, refine the cluster centers on the current batch
            x_global = center_iter(F.normalize(x, dim=-1), F.normalize(x_means, dim=-1))
        else:
            # At inference, use the stored cluster centers directly
            x_global = x_means
        # Compute the global keys and values
        k, v = self.to_k(x_global), self.to_v(x_global)
        k = rearrange(k, 'n (h dim_head)->h n dim_head', h=self.heads)
        v = rearrange(v, 'n (h dim_head)->h n dim_head', h=self.heads)
        return k, v, x_global.detach()
IASA (Intra-Group Self-Attention)
class IASA(nn.Module):
    def __init__(self, dim, qk_dim, heads, group_size):
        super().__init__()
        self.heads = heads
        self.to_q = nn.Linear(dim, qk_dim, bias=False)
        self.to_k = nn.Linear(dim, qk_dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.group_size = group_size

    def forward(self, normed_x, idx_last, k_global, v_global):
        x = normed_x
        B, N, _ = x.shape
        # Compute query, key, value
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        # Gather the tokens into their sorted (grouped) order
        q = torch.gather(q, dim=-2, index=idx_last.expand(q.shape))
        k = torch.gather(k, dim=-2, index=idx_last.expand(k.shape))
        v = torch.gather(v, dim=-2, index=idx_last.expand(v.shape))
        # Grouping
        gs = min(N, self.group_size)  # group size
        ng = (N + gs - 1) // gs       # number of groups
        pad_n = ng * gs - N           # padding length
        # Pad and regroup the queries
        paded_q = torch.cat((q, torch.flip(q[:, N - pad_n:N, :], dims=[-2])), dim=-2)
        paded_q = rearrange(paded_q, "b (ng gs) (h d) -> b ng h gs d", ng=ng, h=self.heads)
        # Pad and regroup the keys (each window spans the previous and current group)
        paded_k = torch.cat((k, torch.flip(k[:, N - pad_n - gs:N, :], dims=[-2])), dim=-2)
        paded_k = paded_k.unfold(-2, 2 * gs, gs)
        paded_k = rearrange(paded_k, "b ng (h d) gs -> b ng h gs d", h=self.heads)
        # Pad and regroup the values
        paded_v = torch.cat((v, torch.flip(v[:, N - pad_n - gs:N, :], dims=[-2])), dim=-2)
        paded_v = paded_v.unfold(-2, 2 * gs, gs)
        paded_v = rearrange(paded_v, "b ng (h d) gs -> b ng h gs d", h=self.heads)
        # Intra-group attention
        out1 = F.scaled_dot_product_attention(paded_q, paded_k, paded_v)
        # Broadcast the global keys and values
        k_global = k_global.reshape(1, 1, *k_global.shape).expand(B, ng, -1, -1, -1)
        v_global = v_global.reshape(1, 1, *v_global.shape).expand(B, ng, -1, -1, -1)
        # Attention against the global token centers
        out2 = F.scaled_dot_product_attention(paded_q, k_global, v_global)
        # Merge the two attention results
        out = out1 + out2
        # Restore the sequence shape and drop the padding
        out = rearrange(out, "b ng h gs d -> b (ng gs) (h d)")[:, :N, :]
        # Scatter the tokens back to their original positions
        out = out.scatter(dim=-2, index=idx_last.expand(out.shape), src=out)
        # Output projection
        out = self.proj(out)
        return out
Helper functions
- center_iter: one iteration of cluster-center refinement
- dists_and_buckets: computes similarities and the bucket (center) each token belongs to
- batched_bincount: per-batch bin counting
- patch_divide and patch_reverse: split an image into patches and stitch it back
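A quick, hypothetical example of how these helpers compose (assuming the definitions from the code above; the tensor sizes are arbitrary):

import torch
import torch.nn.functional as F

x = F.normalize(torch.randn(2, 10, 4), dim=-1)   # (B, N, C) normalized tokens
means = F.normalize(torch.randn(3, 4), dim=-1)   # (M, C) centers

# Nearest center per token, then per-center token counts per batch element.
dists, buckets = dists_and_buckets(x, means)     # buckets: (B, N)
counts = batched_bincount(buckets, num_classes=3)
print(counts)            # (B, M); each row sums to N

# One refinement step: centers move toward the tokens assigned to them.
means = center_iter(x, means, buckets)
print(means.shape)       # torch.Size([3, 4])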
Design highlights
- Dual attention mechanism:
  - IRCA learns the global cluster-center representation
  - IASA performs the local (intra-group) and global attention computation
- Iterative refinement:
  - The cluster centers are progressively refined over several iterations
  - EMA (exponential moving average) smooths the center updates during training; see the worked example after this list
- Grouped processing:
  - Features are split into multiple groups that are processed in parallel
  - Self-attention is computed inside each group
- Residual connections:
  - Keeps the residual-connection design of Transformers
  - Benefits the training of deep networks
- Dynamic updates:
  - Cluster centers are updated dynamically during training
  - At inference time the cluster centers are fixed
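As a worked example of the EMA update used for the centers (plain arithmetic matching ema_inplace above):

decay = 0.999
old, new = 1.0, 2.0
# Each step moves the stored value 0.1% of the way toward the new estimate:
updated = old * decay + new * (1 - decay)
print(updated)  # 1.001 -- a slow, smooth drift that stabilizes the centers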
This module combines clustering ideas with Transformer attention, processing image features effectively while remaining computationally efficient.
Improvement method
Test results
Summary
Full code: