当前位置: 首页 > news >正文

【IEEE 2025】低光增强KANT(使用KAN代替MLP)----论文详解与代码解析

【IEEE 2025】本文参考论文Enhancing Low-Light Images with Kolmogorov–Arnold Networks in Transformer Attention
虽然不是顶刊,但是有值得学习的地方
论文地址:arxiv
源码地址:github

文章目录

  • Part1 --- 论文精读
  • Part2 --- 代码详解
    • 形状追踪代码 (将原代码的n_features 从31修改为32)


Part1 — 论文精读

该论文提出了一种名为 KAN-T 的新型 Transformer 网络,用于低光图像增强 (LLIE)。其核心创新在于引入了一种受 Kolmogorov-Arnold 表示定理启发的 Transformer 注意力机制。
在这里插入图片描述

1. 整体框架 (Overall Framework)

KAN-T 采用了一个 3 级编码器-解码器结构。

  • 输入处理与编码: 输入图像首先通过一个 1 × 1 1 \times 1 1×1 卷积层进行特征扩展,从 H × W × 3 H \times W \times 3 H×W×3 扩展到 H × W × C H \times W \times C H×W×C。随后,图像被送入编码器,该编码器包含不同分辨率级别的 Transformer 模块 ( H × W × C H \times W \times C H×W×C, H 2 × W 2 × 2 C \frac{H}{2} \times \frac{W}{2} \times 2C 2H×2W×2C, 以及 H 4 × W 4 × 4 C \frac{H}{4} \times \frac{W}{4} \times 4C 4H×4W×4C)。编码器的目标是将输入图像转换为包含关键特征的抽象内部表示。
  • 瓶颈层: 编码后的特征图被下采样至 H 8 × W 8 × 8 C \frac{H}{8} \times \frac{W}{8} \times 8C 8H×8W×8C,并通过 KAN-T 的瓶颈层,该瓶颈层利用四个顺序排列的 Transformer 模块来增强内部特征表示。
  • 解码与输出: 内部表示随后进入解码过程,该过程由一系列 Transformer 模块在不同级别组成,与编码器对称排列。最终的 H × W × C H \times W \times C H×W×C 特征图经过卷积操作以减少通道数,生成 H × W × 3 H \times W \times 3 H×W×3 的输出图像。
  • 跳跃连接: KAN-T 在相应的编码器-解码器级别采用跳跃连接,以帮助保留细节和丰富特征。

2. Transformer 模块 (Transformer Block)

Transformer 模块是 KAN-T 的主要构建单元,因其执行高级特征处理的能力而被使用。

  • 组成: 该模块由一个 Kolmogorov-Arnold 多头自注意力 (KAN-MSA) 模块、一个前馈网络 (FFN) 和两个层归一化 (LN) 操作组成,同时在自注意力和特征提取两个阶段之间采用残差连接。
  • 处理流程:
    1. 自注意力阶段: 输入特征图 F i n F_{in} Fin 经过层归一化后,由 KAN-MSA 处理,然后与原始输入 F i n F_{in} Fin 进行残差连接,得到中间特征图 F ^ \hat{F} F^。数学表达式为:
      F ^ = KAN-MSA ( LN ( F i n ) ) + F i n \hat{F} = \text{KAN-MSA}(\text{LN}(F_{in})) + F_{in} F^=KAN-MSA(LN(Fin))+Fin
    2. 特征提取阶段: 中间特征图 F ^ \hat{F} F^ 经过层归一化后,由 FFN 处理,再与 F ^ \hat{F} F^ 进行残差连接,得到输出特征图 F o u t F_{out} Fout。数学表达式为:
      F o u t = FFN ( LN ( F ^ ) ) + F ^ F_{out} = \text{FFN}(\text{LN}(\hat{F})) + \hat{F} Fout=FFN(LN(F^))+F^

3. Kolmogorov-Arnold 网络多头自注意力 (KAN-MSA)
在这里插入图片描述

这是该方法的核心创新点。

  • 标准 MSA 的局限性: 标准多头自注意力 (MSA) 模块利用全连接 (fc) 层来获取查询 (Q)、键 (K) 和值 (V) 分量。虽然 fc 层可以联合处理整个多变量输入来建模复杂关系,但它们可能无法有效捕获单个通道内的单变量关系,并且由于参数数量庞大(尤其对于高维输入)而计算量大。
  • KAN-MSA 原理: 为了克服这些限制,研究者引入了一种基于 KAN 的 MSA 机制,其灵感来源于 Kolmogorov-Arnold 表示定理。该定理指出,任何多变量连续函数都可以表示为连续单变量函数和加法的叠加。新方法还融入了可学习非线性的方面。
  • KAN-MSA 处理流程:
    1. 多变量分解 (通道拆分): 给定输入特征图 F i n ∈ R H × W × C F_{in} \in R^{H \times W \times C} FinRH×W×C,首先执行通道拆分,将其分解为 F 1 , F 2 , . . . , F C F_1, F_2, ..., F_C F1,F2,...,FC,其中每个 F i ∈ R H × W × 1 F_i \in R^{H \times W \times 1} FiRH×W×1。这使得模型能够捕获数据中更复杂和特定的模式。
    2. 单变量处理与可学习非线性: 对于每个通道 i i i F i F_i Fi 通过一个包含三个全连接层序列进行处理,每个层后都有非线性激活函数 Φ j i \Phi_j^i Φji。通过使用三个顺序的 fc 层,模型可以在激活过程中激活或停用某些神经元,从而确保可学习的非线性。
      h 1 i = Φ i 1 ( W i 1 F i + b i 1 ) h_1^i = \Phi_i^1(W_i^1 F_i + b_i^1) h1i=Φi1(Wi1Fi+bi1) h i 2 = Φ i 2 ( W i 2 h i 1 + b i 2 ) h_i^2 = \Phi_i^2(W_i^2 h_i^1 + b_i^2) hi2=Φi2(Wi2hi1+bi2) h i 3 = Φ i 3 ( W i 3 h i 2 + b i 3 ) , h i 3 ∈ R H × W × 3 h_i^3 = \Phi_i^3(W_i^3 h_i^2 + b_i^3), \quad h_i^3 \in \mathbb{R}^{H \times W \times 3} hi3=Φi3(Wi3hi2+bi3),hi3RH×W×3
    3. 合并与 QKV 生成: 单变量处理的结果在通道维度上进行拼接,得到 F o u t ∈ R H × W × 3 C F_{out} \in \mathbb{R}^{H \times W \times 3C} FoutRH×W×3C,然后将其三向拆分以获得 Q、K、 V ∈ R H × W × C V \in \mathbb{R}^{H \times W \times C} VRH×W×C
    4. 自注意力计算: Q、K、V 被重塑为 H W × C HW \times C HW×C,并用于生成自注意力特征图 F o u t F_{out} Fout
      F o u t = V × softmax ( K Q T T ) F_{out} = V \times \text{softmax}(\frac{K Q^T}{\mathcal{T}}) Fout=V×softmax(TKQT)
      其中 T \mathcal{T} T 是一个可学习的参数,用于平衡注意力分数。 F o u t F_{out} Fout 随后被重塑回 H × W × C H \times W \times C H×W×C

4. 前馈网络 (Feed-Forward Network, FFN)

FFN 是 Transformer 模块的另一个关键组成部分,它使用自注意力特征图进行深度特征提取。

  • 结构: 它采用三重卷积设置,并使用高斯误差线性单元 (GELU) 激活函数 ( ψ \psi ψ)。
  • 处理流程: 给定输入特征图 F i n ∈ R H × W × C F_{in} \in \mathbb{R}^{H \times W \times C} FinRH×W×C,其计算公式为:
    F o u t = conv1 × 1 ( ψ conv3 × 3 ( ψ conv1 × 1 ( F i n ) ) ) F_{out} = \text{conv1} \times \text{1}(\psi \text{conv3} \times \text{3}(\psi \text{conv1} \times \text{1}(F_{in}))) Fout=conv1×1(ψconv3×3(ψconv1×1(Fin)))
    其中,第一个 c o n v 1 × 1 conv1 \times 1 conv1×1 将特征图扩展到 H × W × 4 C H \times W \times 4C H×W×4C 以帮助发现新模式; c o n v 3 × 3 conv3 \times 3 conv3×3 通过增加核大小执行高分辨率特征提取;最后一个 c o n v 1 × 1 conv1 \times 1 conv1×1 将特征图压缩回原始维度 H × W × C H \times W \times C H×W×C

5. 损失函数 (Loss Function)

为了实现精确重建,采用了一个复合损失函数 L \mathcal{L} L。该混合损失函数集成了多个分量以解决图像质量的各个方面,包括像素级准确性、结构完整性和感知保真度。

  • 总体损失:
    L = L M A E + α ⋅ L M S − S S I M + β ⋅ L P e r c \mathcal{L} = \mathcal{L}_{MAE} + \alpha \cdot \mathcal{L}_{MS-SSIM} + \beta \cdot \mathcal{L}_{Perc} L=LMAE+αLMSSSIM+βLPerc
    其中 α \alpha α β \beta β 是平衡每个损失分量贡献的超参数。
  • 各分量原理:
    • 平均绝对误差损失 ( L M A E \mathcal{L}_{MAE} LMAE): 作为主要项,它捕获预测图像 I ^ \hat{I} I^ 和真实图像 I G T \mathcal{I}_{GT} IGT 之间的平均差异。
      L M A E ( x , y ) = 1 N ∑ x , y ∣ ∣ I ^ ( x , y ) − I G T ( x , y ) ∣ ∣ 1 \mathcal{L}_{MAE}(x,y) = \frac{1}{N} \sum_{x,y} ||\hat{I}(x,y) - \mathcal{I}_{GT}(x,y)||_1 LMAE(x,y)=N1x,y∣∣I^(x,y)IGT(x,y)1
    • 多尺度结构相似性指数度量损失 ( L M S − S S I M \mathcal{L}_{MS-SSIM} LMSSSIM): 评估预测图像和真实图像在多个尺度上的结构相似性。它通过评估结构失真(尤其是在低光等挑战性条件下)来捕获对保持图像结构完整性至关重要的高级特征。
    • 感知损失 ( L P e r c \mathcal{L}_{Perc} LPerc): 利用预训练的 VGG-19 网络 ( Ψ \Psi Ψ) 来引入特征级监督。该损失测量预测图像和真实图像的高级特征表示之间的差异,有助于学习有意义的内部表示。
      L P e r c ( x , y ) = 1 N ∑ x , y ∣ ∣ Ψ ( I ^ ( x , y ) ) − Ψ ( I G T ( x , y ) ) ∣ ∣ 1 \mathcal{L}_{Perc}(x,y) = \frac{1}{N} \sum_{x,y} ||\Psi(\hat{I}(x,y)) - \Psi(\mathcal{I}_{GT}(x,y))||_1 LPerc(x,y)=N1x,y∣∣Ψ(I^(x,y))Ψ(IGT(x,y))1
      通过集成这三个损失分量,该混合损失函数有效地平衡了像素级准确性、结构一致性和感知质量。
      在这里插入图片描述

总而言之,该方法的核心原理在于利用 Kolmogorov-Arnold 表示定理的思想改进 Transformer 中的多头自注意力机制,通过将多变量函数分解为单变量函数和线性组合,并引入可学习的非线性激活函数,从而在低光图像增强任务中实现更灵活、更有效的特征表示和上下文信息捕获。结合精心设计的编码器-解码器架构和复合损失函数,旨在实现卓越的性能。


Part2 — 代码详解

形状追踪代码 (将原代码的n_features 从31修改为32)

现在,我们编写代码来实例化 KANT 模型,并在其 forward 方法的每个重要步骤打印张量形状。为了便于演示,我将稍微修改 KANT 类,以便在其前向传播中更容易打印形状。

import torch
import torch.nn as nn
import numbers
import torch.nn.functional as F
from einops import rearrange
import math# [KANT.py 中的所有类定义 (LayerNorm, GELU, KANAttention, FFN2, TransformerBlock 等) 必须粘贴在此处]
# ... (假设所有必要的类定义,如 LayerNorm, GELU, KolmogorovArnoldNetwork, KANAttention, FFN2, TransformerBlock 都已在此处定义) ...# --- KANT.py 内容开始 (为独立执行而复制) ---
# 工具函数
def to_3d(x):return rearrange(x, 'b c h w -> b (h w) c')def to_4d(x,h,w):return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)class BiasFree_LayerNorm(nn.Module):def __init__(self, normalized_shape):super(BiasFree_LayerNorm, self).__init__()if isinstance(normalized_shape, numbers.Integral):normalized_shape = (normalized_shape,)normalized_shape = torch.Size(normalized_shape)assert len(normalized_shape) == 1self.weight = nn.Parameter(torch.ones(normalized_shape))self.normalized_shape = normalized_shapedef forward(self, x):sigma = x.var(-1, keepdim=True, unbiased=False)return x / torch.sqrt(sigma+1e-5) * self.weightclass WithBias_LayerNorm(nn.Module):def __init__(self, normalized_shape):super(WithBias_LayerNorm, self).__init__()if isinstance(normalized_shape, numbers.Integral):normalized_shape = (normalized_shape,)normalized_shape = torch.Size(normalized_shape)assert len(normalized_shape) == 1self.weight = nn.Parameter(torch.ones(normalized_shape))self.bias = nn.Parameter(torch.zeros(normalized_shape))self.normalized_shape = normalized_shapedef forward(self, x):mu = x.mean(-1, keepdim=True)sigma = x.var(-1, keepdim=True, unbiased=False)return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.biasclass LayerNorm(nn.Module):def __init__(self, dim, LayerNorm_type):super(LayerNorm, self).__init__()if LayerNorm_type =='BiasFree':self.body = BiasFree_LayerNorm(dim)else:self.body = WithBias_LayerNorm(dim)def forward(self, x):h, w = x.shape[-2:]return to_4d(self.body(to_3d(x)), h, w)class GELU(nn.Module):def forward(self, x):return F.gelu(x)# KolmogorovArnoldNetwork (基于 MLP 的 KAN)
class KolmogorovArnoldNetwork(nn.Module):def __init__(self, input_channels, hidden_size=256):super(KolmogorovArnoldNetwork, self).__init__()self.input_channels = input_channelsself.hidden_size = hidden_sizeself.fc1_list = nn.ModuleList([nn.Linear(1, hidden_size) for _ in range(input_channels)])self.fc2_list = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(input_channels)])self.fc3_list = nn.ModuleList([nn.Linear(hidden_size, 3) for _ in range(input_channels)]) # 输出3个用于Q,K,V部分self.relu = nn.ReLU()def forward(self, x): # 期望输入 x 形状: (batch_size, H, W, C)batch_size, H, W, C = x.shapex_reshaped = x.reshape(-1, C)outputs_mlp = []for i in range(self.input_channels):xi = x_reshaped[:, i:i+1]xi = self.relu(self.fc1_list[i](xi))xi = self.relu(self.fc2_list[i](xi))xi = self.fc3_list[i](xi)outputs_mlp.append(xi)x_cat = torch.cat(outputs_mlp, dim=1)x_final = x_cat.view(batch_size, H, W, C*3)return x_final# KANAttention
class KANAttention(nn.Module):def __init__(self, dim, num_heads, bias=True):super(KANAttention, self).__init__()self.num_heads = num_headsself.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))self.proj_in = KolmogorovArnoldNetwork(input_channels=dim, hidden_size=dim)self.proj_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)def apply_kan(self, kan_layer, x_in_bcwh):x_permuted_for_kan = x_in_bcwh.permute(0, 2, 3, 1).contiguous()kan_output_bhwc = kan_layer(x_permuted_for_kan)x_out_bcwh = kan_output_bhwc.permute(0, 3, 1, 2).contiguous()return x_out_bcwhdef forward(self, x):b,c,h,w = x.shapeqkv = self.apply_kan(self.proj_in, x)q,k,v = qkv.chunk(3, dim=1)q = rearrange(q, 'b (head c_head) h w -> b head c_head (h w)', head=self.num_heads)k = rearrange(k, 'b (head c_head) h w -> b head c_head (h w)', head=self.num_heads)v = rearrange(v, 'b (head c_head) h w -> b head c_head (h w)', head=self.num_heads)q = torch.nn.functional.normalize(q, dim=-1)k = torch.nn.functional.normalize(k, dim=-1)attn = (q @ k.transpose(-2, -1)) * self.temperatureattn = attn.softmax(dim=-1)out = (attn @ v)out = rearrange(out, 'b head c_head (h w) -> b (head c_head) h w', head=self.num_heads, h=h, w=w)out = self.proj_out(out)return out# FFN2
class FFN2(nn.Module):def __init__(self, dim, mult=4):super().__init__()self.net = nn.Sequential(nn.Conv2d(dim, dim * mult, 1, 1, bias=False), GELU(),nn.Conv2d(dim * mult, dim * mult, 3, 1, 1, bias=False, groups=dim * mult), GELU(),nn.Conv2d(dim * mult, dim, 1, 1, bias=False),)def forward(self, x):return self.net(x)# TransformerBlock
class TransformerBlock(nn.Module):def __init__(self, in_channels, num_heads, num_experts, dim_feedforward=None, dropout=0.1, LayerNorm_type='WithBias'):super(TransformerBlock, self).__init__()self.attention = KANAttention(dim=in_channels, num_heads=num_heads)self.norm1 = LayerNorm(dim=in_channels, LayerNorm_type=LayerNorm_type)self.moe = FFN2(dim=in_channels) # 使用 FFN2self.norm2 = LayerNorm(dim=in_channels, LayerNorm_type=LayerNorm_type)def forward(self, x):f_in_normed_for_attn = self.norm1(x)attended_features = self.attention(f_in_normed_for_attn)x = x + attended_featuresf_hat_normed_for_ffn = self.norm2(x)ffn_features = self.moe(f_hat_normed_for_ffn)x = x + ffn_featuresreturn x# KANT 模型 (带有形状打印功能)
class KANT_ShapeTracer(nn.Module):def __init__(self, in_channels=3, out_channels=3, n_feat=31): # 为追踪简化参数super(KANT_ShapeTracer, self).__init__()print(f"--- KANT 模型初始化 ---")print(f"输入通道数: {in_channels}, 输出通道数: {out_channels}, 基础特征数 (n_feat): {n_feat}\n")num_heads_start = 2num_experts = None # 未使用,因为直接使用 FFN2self.conv_in = nn.Conv2d(in_channels, n_feat, kernel_size=1, padding='same')print(f"  conv_in: Conv2d({in_channels}, {n_feat}, kernel_size=1)")# 第 1 层编码器current_heads_l1 = num_heads_startself.transformer_block1_1 = TransformerBlock(n_feat, current_heads_l1, num_experts)print(f"  transformer_block1_1: TransformerBlock(n_feat={n_feat}, heads={current_heads_l1})")self.downsample1 = nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=2, padding=1)print(f"  downsample1: Conv2d({n_feat}, {n_feat*2}, kernel_size=3, stride=2)")# 第 2 层编码器current_heads_l2 = num_heads_start # 在瓶颈层调整前,与l1保持一致 (根据原始代码)self.transformer_block2_1 = TransformerBlock(n_feat * 2, current_heads_l2, num_experts)print(f"  transformer_block2_1: TransformerBlock(n_feat={n_feat*2}, heads={current_heads_l2})")self.transformer_block2_2 = TransformerBlock(n_feat * 2, current_heads_l2, num_experts)print(f"  transformer_block2_2: TransformerBlock(n_feat={n_feat*2}, heads={current_heads_l2})")self.downsample2 = nn.Conv2d(n_feat * 2, n_feat * 4, kernel_size=3, stride=2, padding=1)print(f"  downsample2: Conv2d({n_feat*2}, {n_feat*4}, kernel_size=3, stride=2)")# 瓶颈层current_heads_bn = current_heads_l2 * 2 # 瓶颈层头数加倍self.bottleneck_1 = TransformerBlock(n_feat * 4, current_heads_bn, num_experts)print(f"  bottleneck_1: TransformerBlock(n_feat={n_feat*4}, heads={current_heads_bn})")self.bottleneck_2 = TransformerBlock(n_feat * 4, current_heads_bn, num_experts)print(f"  bottleneck_2: TransformerBlock(n_feat={n_feat*4}, heads={current_heads_bn})")# 第 2 层解码器current_heads_up2 = current_heads_bn // 2self.upsample2 = nn.ConvTranspose2d(n_feat * 4, n_feat * 2, kernel_size=3, stride=2, padding=1, output_padding=1)print(f"  upsample2: ConvTranspose2d({n_feat*4}, {n_feat*2}, kernel_size=3, stride=2)")self.channel_adjust2 = nn.Conv2d(n_feat * 4, n_feat * 2, kernel_size=1) # 输入是 n_feat*2 (上采样) + n_feat*2 (跳跃) = n_feat*4print(f"  channel_adjust2: Conv2d({n_feat*4}, {n_feat*2}, kernel_size=1)")self.transformer_block_up2_1 = TransformerBlock(n_feat * 2, current_heads_up2, num_experts)print(f"  transformer_block_up2_1: TransformerBlock(n_feat={n_feat*2}, heads={current_heads_up2})")self.transformer_block_up2_2 = TransformerBlock(n_feat * 2, current_heads_up2, num_experts)print(f"  transformer_block_up2_2: TransformerBlock(n_feat={n_feat*2}, heads={current_heads_up2})")# 第 1 层解码器current_heads_up1 = current_heads_up2 // 2self.upsample1 = nn.ConvTranspose2d(n_feat * 2, n_feat, kernel_size=3, stride=2, padding=1, output_padding=1)print(f"  upsample1: ConvTranspose2d({n_feat*2}, {n_feat}, kernel_size=3, stride=2)")self.channel_adjust1 = nn.Conv2d(n_feat * 2, n_feat, kernel_size=1) # 输入是 n_feat (上采样) + n_feat (跳跃) = n_feat*2print(f"  channel_adjust1: Conv2d({n_feat*2}, {n_feat}, kernel_size=1)")self.transformer_block_up1_1 = TransformerBlock(n_feat, current_heads_up1, num_experts)print(f"  transformer_block_up1_1: TransformerBlock(n_feat={n_feat}, heads={current_heads_up1})")self.conv_out = nn.Conv2d(n_feat, out_channels, kernel_size=1, padding='same')print(f"  conv_out: Conv2d({n_feat}, {out_channels}, kernel_size=1)")print(f"--- KANT 模型初始化结束 ---\n")def forward(self, x):print(f"\n--- KANT 前向传播形状追踪 ---")print(f"初始输入形状: {x.shape}")x = self.conv_in(x)print(f"经过 conv_in 后: {x.shape}")# 编码器路径x1 = self.transformer_block1_1(x)print(f"经过 transformer_block1_1 (x1) 后: {x1.shape}")x1_down = self.downsample1(x1)print(f"经过 downsample1 (x1_down) 后: {x1_down.shape}")x2 = self.transformer_block2_1(x1_down)print(f"经过 transformer_block2_1 后: {x2.shape}")x2 = self.transformer_block2_2(x2)print(f"经过 transformer_block2_2 (x2) 后: {x2.shape}")x2_down = self.downsample2(x2)print(f"经过 downsample2 (x2_down) 后: {x2_down.shape}")# 瓶颈层bn = self.bottleneck_1(x2_down)print(f"经过 bottleneck_1 后: {bn.shape}")bn = self.bottleneck_2(bn)print(f"经过 bottleneck_2 (bn) 后: {bn.shape}")# 解码器路径x2_up_pre_cat = self.upsample2(bn)print(f"经过 upsample2 (x2_up_pre_cat) 后: {x2_up_pre_cat.shape}")x2_up = torch.cat([x2_up_pre_cat, x2], dim=1)print(f"经过 cat([x2_up_pre_cat, x2]) 后: {x2_up.shape}")x2_up = self.channel_adjust2(x2_up)print(f"经过 channel_adjust2 (x2_up) 后: {x2_up.shape}")x2_up = self.transformer_block_up2_1(x2_up)print(f"经过 transformer_block_up2_1 后: {x2_up.shape}")x2_up = self.transformer_block_up2_2(x2_up)print(f"经过 transformer_block_up2_2 (x2_up) 后: {x2_up.shape}")x1_up_pre_cat = self.upsample1(x2_up)print(f"经过 upsample1 (x1_up_pre_cat) 后: {x1_up_pre_cat.shape}")x1_up = torch.cat([x1_up_pre_cat, x1], dim=1)print(f"经过 cat([x1_up_pre_cat, x1]) 后: {x1_up.shape}")x1_up = self.channel_adjust1(x1_up)print(f"经过 channel_adjust1 (x1_up) 后: {x1_up.shape}")x1_up = self.transformer_block_up1_1(x1_up)print(f"经过 transformer_block_up1_1 (x1_up) 后: {x1_up.shape}")x_out = self.conv_out(x1_up)print(f"经过 conv_out (最终输出) 后: {x_out.shape}")print(f"--- KANT 前向传播形状追踪结束 ---\n")return x_out# --- KANT.py 内容结束 ---if __name__ == '__main__':# 用于测试的示例参数batch_size = 1input_channels = 3height, width = 256, 256 # 示例图像尺寸n_features = 31 # 基础特征数量,与 KANT 类中一致# 创建一个虚拟输入张量dummy_input = torch.randn(batch_size, input_channels, height, width)print(f"正在创建 KANT_ShapeTracer 模型,n_feat={n_features}...")# 实例化模型model = KANT_ShapeTracer(in_channels=input_channels, out_channels=input_channels, n_feat=n_features)# 执行一次前向传播以追踪形状print(f"使用形状为 {dummy_input.shape} 的虚拟输入执行前向传播")with torch.no_grad(): # 追踪形状时无需计算梯度output = model(dummy_input)print(f"最终输出张量形状: {output.shape}")# 你也可以打印模型结构# print("\n模型结构:")# print(model)

如何运行形状追踪脚本:

  1. 将上述代码保存为一个 Python 文件 (例如, trace_kant_shape_cn.py)。
  2. 从终端运行它: python trace_kant_shape_cn.py

这将打印 KANT_ShapeTracer 模型在初始化时的配置,然后在虚拟输入通过 forward 方法中的每个重要层/操作时追踪张量形状。n_feat 参数设置为 31,与你的 KANT 类默认值一致。你可以更改 heightwidth 来测试不同的输入分辨率。

http://www.xdnf.cn/news/627733.html

相关文章:

  • Java——设计模式(Design Pattern)
  • DAY 35
  • Shell三剑客之awk
  • 全球化 2.0 | 云轴科技ZStack助力中东智慧城市高性能智能安防云平台
  • TypeScript小技巧使用as const:让类型推断更精准。
  • exti line2 interrupt 如何写中断回调
  • 数据库中表的设计规范
  • 【软考】【信息系统项目管理师】2025年5月24日考试回忆版,祝明天考试的兄弟们顺利
  • maxkey单点登录系统
  • Neo4j(二) - 使用Cypher操作Neo4j
  • iOS 直播特殊礼物特效实现方案(Swift实现,超详细!)
  • STM32F446主时钟失效时DAC输出异常现象解析与解决方案
  • AtCoder AT_abc407_d [ABC407D] Domino Covering XOR
  • 【Web前端】jQuery入门与基础(二)
  • 免费PDF工具-PDF24V9.16.0【win7专用版】
  • TypeScript基础数据类型详解总结
  • 常见的图像生成模型
  • 嵌入式开发学习日志(linux系统编程--进程(1))Day27
  • winsever2016Web服务器平台安装与配置
  • python训练营day34
  • TIT-2014《Randomized Dimensionality Reduction for $k$-means Clustering》
  • 第十天的尝试
  • 快速排序算法的C++和C语言对比
  • Python实用工具:文件批量重命名器
  • Unity3D仿星露谷物语开发49之创建云杉树
  • 常见算法题目3 -反转字符串
  • 2025年—ComfyUI_最新插件推荐及使用(实时更新)
  • 保姆式一步一步制作B端左侧菜单栏
  • 游园安排--最长上升子序列+输出序列
  • 力扣:《螺旋矩阵》系列题目