
Week J9: Inception v3, Hands-On Practice and Analysis

  • 🍨 This article is a learning-record blog post from the 🔗 365-Day Deep Learning Training Camp
  • 🍖 Original author: K同学啊

Objective

Implement the Inception V3 model and train it to recognize weather conditions.

Concepts

Inception V3 is a deep convolutional neural network architecture proposed by Google in 2015 and a major refinement within the Inception family.
Inception module optimization: V3 keeps V1's core idea of applying convolution kernels of different sizes within the same layer, but refines it substantially. The module runs 1×1, 3×3, and 5×5 convolutions and 3×3 max pooling in parallel and concatenates the results, letting the network learn the most useful combination of features on its own.
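As a rough standalone sketch of this parallel-branch idea (the layer widths here are arbitrary illustrative choices, not the V3 values; the full modules appear in the implementation section below):

import torch
from torch import nn
import torch.nn.functional as F

class ToyInception(nn.Module):
    """Toy parallel-branch block: 1x1, 3x3, 5x5 convs plus 3x3 max pooling."""
    def __init__(self, in_channels):
        super().__init__()
        self.b1 = nn.Conv2d(in_channels, 16, kernel_size=1)
        self.b3 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.b5 = nn.Conv2d(in_channels, 16, kernel_size=5, padding=2)
        self.pool_proj = nn.Conv2d(in_channels, 16, kernel_size=1)

    def forward(self, x):
        branches = [
            self.b1(x),
            self.b3(x),
            self.b5(x),
            self.pool_proj(F.max_pool2d(x, kernel_size=3, stride=1, padding=1)),
        ]
        return torch.cat(branches, dim=1)  # output channels: 16 * 4 = 64

x = torch.randn(1, 32, 35, 35)
print(ToyInception(32)(x).shape)  # torch.Size([1, 64, 35, 35])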

Key Technical Innovations

1. Convolution Factorization

  • Decompose large kernels into small ones: a 7×7 convolution becomes a stack of 3×3 convolutions
  • Decompose n×n convolutions into asymmetric 1×n and n×1 convolutions, e.g. 3×3 into 1×3 followed by 3×1
  • This markedly reduces parameters and computation while adding depth (see the arithmetic sketch below)
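A quick back-of-the-envelope check of those savings (weight counts only, ignoring biases and BN parameters; the channel count is an arbitrary example):

# Weight count of a conv layer (ignoring bias): k_h * k_w * c_in * c_out
c = 192  # arbitrary channel count, kept equal for input and output

five_by_five = 5 * 5 * c * c                        # one 5x5 conv
two_three    = 2 * (3 * 3 * c * c)                  # two stacked 3x3 convs, same receptive field
asym_three   = (1 * 3 * c * c) + (3 * 1 * c * c)    # 1x3 followed by 3x1

print(five_by_five / two_three)        # 25/18 ≈ 1.39x more than two 3x3 layers
print(five_by_five / (3 * 3 * c * c))  # 25/9  ≈ 2.78x more than a single 3x3
print((3 * 3 * c * c) / asym_three)    # 9/6   = 1.5x more than the 1x3 + 3x1 pair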

2. Efficient Dimensionality Reduction

  • Use 1×1 convolutions to reduce channel counts and lighten the computational load
  • Reduce dimensionality before expensive operations, expand it again after cheap ones
  • Shrink feature-map sizes carefully to avoid representational bottlenecks (a cost comparison follows this list)
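To make the saving concrete, here is a sketch of the weight count with and without a 1×1 bottleneck in front of a 3×3 convolution. The channel numbers mirror the branch3x3dbl path of the InceptionA module further down (192 -> 64 -> 96); the unreduced baseline is a hypothetical comparison, not a layer from the model:

# Effect of a 1x1 bottleneck before an expensive 3x3 conv (weights only)
c_in, c_mid, c_out = 192, 64, 96

direct     = 3 * 3 * c_in * c_out                           # 3x3 straight on 192 channels
bottleneck = 1 * 1 * c_in * c_mid + 3 * 3 * c_mid * c_out   # 1x1 reduce, then 3x3

print(direct, bottleneck, direct / bottleneck)  # 165888 67584 ≈ 2.45x fewer weights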

3. Improved Auxiliary Classifier

  • An auxiliary classifier partway through the network mitigates vanishing gradients
  • Batch Normalization in the auxiliary head takes over the regularizing role of dropout
  • The auxiliary loss is weighted by 0.3 so it does not dominate the main objective (see the sketch after this list)
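In training code this weighting typically looks like the following sketch (a minimal example, assuming the inceptionv3.py from step 1 below is importable; the class count and batch here are dummies, and the main training script in step 3 does not actually enable aux_logits):

import torch
from torch import nn
from inceptionv3 import InceptionV3  # the file from step 1 below

model = InceptionV3(num_classes=5, aux_logits=True)
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(2, 3, 299, 299)   # dummy batch
labels = torch.randint(0, 5, (2,))

model.train()  # the aux output only exists in training mode (see forward() below)
logits, aux_logits = model(inputs)
loss = criterion(logits, labels) + 0.3 * criterion(aux_logits, labels)
loss.backward()
print(loss.item())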

Architectural Features

Modular Design

  • Three types of Inception module (A, B, C) are used
  • Each type is tuned for feature maps at a different resolution
  • Between stages, Grid Size Reduction modules downsample efficiently (a toy version follows this list)
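A toy standalone version of the reduction pattern is sketched below (widths chosen so the shapes match Mixed_6a; the real versions are the ReductionA and ReductionB modules in step 1). A stride-2 convolution and stride-2 pooling run in parallel and are concatenated, so channels grow as the spatial grid shrinks:

import torch
from torch import nn
import torch.nn.functional as F

class ToyReduction(nn.Module):
    """Toy grid-size reduction: stride-2 conv branch + stride-2 pooling branch."""
    def __init__(self, in_channels, conv_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, conv_channels, kernel_size=3, stride=2)

    def forward(self, x):
        return torch.cat([
            self.conv(x),                              # learned downsampling
            F.max_pool2d(x, kernel_size=3, stride=2),  # parameter-free downsampling
        ], dim=1)

x = torch.randn(1, 288, 35, 35)
print(ToyReduction(288, 480)(x).shape)  # torch.Size([1, 768, 17, 17])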

Integrated Batch Normalization

  • Batch Normalization follows nearly every convolutional layer
  • Training converges noticeably faster and more stably
  • Sensitivity to weight initialization is reduced

Performance

Computational Efficiency

  • Factorized convolutions need roughly 2.5–3× fewer weights than their unfactorized equivalents (a 5×5 kernel costs 25/9 ≈ 2.8× as much as a 3×3)
  • Computational complexity drops substantially, so inference is faster
  • Memory is used more efficiently

Accuracy Gains

  • Achieved state-of-the-art classification accuracy on ImageNet at the time of publication
  • Top-5 error rate dropped to about 5.6%
  • The model generalizes well

Design Principles

1. Avoid representational bottlenecks. Especially early in the network, avoid extreme compression that discards information.

2. Higher-dimensional representations are easier to process locally. Increasing the number of activations per tile in a convolutional network speeds up convergence.

3. Spatial aggregation can be done over lower-dimensional embeddings. Inputs can be reduced in dimension before a 3×3 or 5×5 convolution without serious loss of expressive power.

4. Balance network width and depth. Optimal performance comes from balancing the number of filters per stage against the depth of the network.

Implementation

(1) Environment

Language: Python 3.10
IDE: PyCharm
Framework: PyTorch

(2) Steps
1. Inceptionv3.py
  
import torch  
import torch.nn.functional as F  
from torch import nn


class InceptionA(nn.Module):
    def __init__(self, in_channels, pool_features):
        super(InceptionA, self).__init__()
        # Branch 1: 1x1 convolution
        self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)

        # Branch 2: 1x1 reduction followed by a 5x5 convolution
        self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)

        # Branch 3: 1x1 reduction followed by two stacked 3x3 convolutions
        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)

        # Branch 4: average pooling followed by a 1x1 projection
        self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class InceptionB(nn.Module):
    def __init__(self, in_channels, channels_7x7):
        super(InceptionB, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)

        # 7x7 convolutions factorized into asymmetric 1x7 and 7x1 pairs
        c7 = channels_7x7
        self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class InceptionC(nn.Module):
    def __init__(self, in_channels):
        super(InceptionC, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)

        # 3x3 convolution split into parallel 1x3 and 3x1 branches
        self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class ReductionA(nn.Module):
    def __init__(self, in_channels):
        super(ReductionA, self).__init__()
        self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2)

    def forward(self, x):
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class ReductionB(nn.Module):
    def __init__(self, in_channels):
        super(ReductionB, self).__init__()
        self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2)

    def forward(self, x):
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch7x7x3, branch_pool]
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
        self.conv1 = BasicConv2d(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x):
        # 17 x 17 x 768
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # 5 x 5 x 768
        x = self.conv0(x)
        # 5 x 5 x 128
        x = self.conv1(x)
        # 1 x 1 x 768
        x = x.view(x.size(0), -1)
        # 768
        x = self.fc(x)
        # num_classes
        return x


class BasicConv2d(nn.Module):
    # Conv -> BatchNorm -> ReLU, the basic building block used everywhere above
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)


class InceptionV3(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=False, transform_input=False):
        super(InceptionV3, self).__init__()
        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        self.Mixed_6a = ReductionA(288)
        self.Mixed_6b = InceptionB(768, channels_7x7=128)
        self.Mixed_6c = InceptionB(768, channels_7x7=160)
        self.Mixed_6d = InceptionB(768, channels_7x7=160)
        self.Mixed_6e = InceptionB(768, channels_7x7=192)
        if aux_logits:
            self.AuxLogits = InceptionAux(768, num_classes)
        self.Mixed_7a = ReductionB(768)
        self.Mixed_7b = InceptionC(1280)
        self.Mixed_7c = InceptionC(2048)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        if self.transform_input:
            # Remap inputs normalized with ImageNet statistics to the [-1, 1] range
            x = x.clone()
            x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
        # 299 x 299 x 3
        x = self.Conv2d_1a_3x3(x)
        # 149 x 149 x 32
        x = self.Conv2d_2a_3x3(x)
        # 147 x 147 x 32
        x = self.Conv2d_2b_3x3(x)
        # 147 x 147 x 64
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # 73 x 73 x 64
        x = self.Conv2d_3b_1x1(x)
        # 73 x 73 x 80
        x = self.Conv2d_4a_3x3(x)
        # 71 x 71 x 192
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # 35 x 35 x 192
        x = self.Mixed_5b(x)
        # 35 x 35 x 256
        x = self.Mixed_5c(x)
        # 35 x 35 x 288
        x = self.Mixed_5d(x)
        # 35 x 35 x 288
        x = self.Mixed_6a(x)
        # 17 x 17 x 768
        x = self.Mixed_6b(x)
        # 17 x 17 x 768
        x = self.Mixed_6c(x)
        # 17 x 17 x 768
        x = self.Mixed_6d(x)
        # 17 x 17 x 768
        x = self.Mixed_6e(x)
        # 17 x 17 x 768
        if self.training and self.aux_logits:
            aux = self.AuxLogits(x)
        # 17 x 17 x 768
        x = self.Mixed_7a(x)
        # 8 x 8 x 1280
        x = self.Mixed_7b(x)
        # 8 x 8 x 2048
        x = self.Mixed_7c(x)
        # 8 x 8 x 2048
        x = F.avg_pool2d(x, kernel_size=8)
        # 1 x 1 x 2048
        x = F.dropout(x, training=self.training)
        # 1 x 1 x 2048
        x = x.view(x.size(0), -1)
        # 2048
        x = self.fc(x)
        # num_classes
        if self.training and self.aux_logits:
            return x, aux
        return x


if __name__ == '__main__':
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using {} device".format(device))

    model = InceptionV3().to(device)
    print(model)

    # Report parameter counts and other statistics
    import torchsummary as summary
    summary.summary(model, (3, 299, 299))
Using cuda device
InceptionV3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  ... (remaining module printout omitted; it mirrors the class definitions above) ...
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 32, 149, 149]             864
       BatchNorm2d-2         [-1, 32, 149, 149]              64
       BasicConv2d-3         [-1, 32, 149, 149]               0
            Conv2d-4         [-1, 32, 147, 147]           9,216
       BatchNorm2d-5         [-1, 32, 147, 147]              64
       BasicConv2d-6         [-1, 32, 147, 147]               0
            ... (per-layer rows 7 through 292 omitted) ...
       InceptionC-293           [-1, 2048, 8, 8]               0
           Linear-294                 [-1, 1000]       2,049,000
================================================================
Total params: 23,834,568
Trainable params: 23,834,568
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.02
Forward/backward pass size (MB): 224.12
Params size (MB): 90.92
Estimated Total Size (MB): 316.07
----------------------------------------------------------------
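The 23,834,568 total can also be cross-checked directly (a quick sketch, assuming the Inceptionv3.py above is importable):

from inceptionv3 import InceptionV3

model = InceptionV3()
total = sum(p.numel() for p in model.parameters())
print(f"{total:,}")  # 23,834,568, matching the torchsummary total above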
2. preprocess_dataset.py
import os  
import shutil  
import random  
from PIL import Image  
import argparse  
from tqdm import tqdm


def preprocess_dataset(input_dir, output_dir, split_ratio=0.8):
    """
    Preprocess the dataset into train/val splits.

    Args:
    - input_dir: input data directory
    - output_dir: output data directory
    - split_ratio: train/validation split ratio
    """
    # Check the input directory
    if not os.path.exists(input_dir):
        raise ValueError(f"Input directory {input_dir} does not exist")

    # Create the output directory structure
    train_dir = os.path.join(output_dir, 'train')
    val_dir = os.path.join(output_dir, 'val')

    # Make sure the output directories exist
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    # Collect all classes
    classes = [d for d in os.listdir(input_dir) if os.path.isdir(os.path.join(input_dir, d))]

    if not classes:
        # If the input directory has no subdirectories, try organizing by file extension
        print("No class directories found. Trying to organize by file extensions...")
        try:
            # Assume normal images are .jpg and monkeypox images are .png (or some other distinction)
            normal_images = [f for f in os.listdir(input_dir) if f.lower().endswith('.jpg')]
            monkeypox_images = [f for f in os.listdir(input_dir) if f.lower().endswith('.png')]

            # Create the class directories
            os.makedirs(os.path.join(train_dir, "normal"), exist_ok=True)
            os.makedirs(os.path.join(train_dir, "monkeypox"), exist_ok=True)
            os.makedirs(os.path.join(val_dir, "normal"), exist_ok=True)
            os.makedirs(os.path.join(val_dir, "monkeypox"), exist_ok=True)

            # Process normal images
            random.shuffle(normal_images)
            split_idx = int(len(normal_images) * split_ratio)
            print("Processing normal images...")
            for i, img in enumerate(tqdm(normal_images)):
                src = os.path.join(input_dir, img)
                if i < split_idx:
                    dst = os.path.join(train_dir, "normal", img)
                else:
                    dst = os.path.join(val_dir, "normal", img)
                shutil.copy2(src, dst)

            # Process monkeypox images
            random.shuffle(monkeypox_images)
            split_idx = int(len(monkeypox_images) * split_ratio)
            print("Processing monkeypox images...")
            for i, img in enumerate(tqdm(monkeypox_images)):
                src = os.path.join(input_dir, img)
                if i < split_idx:
                    dst = os.path.join(train_dir, "monkeypox", img)
                else:
                    dst = os.path.join(val_dir, "monkeypox", img)
                shutil.copy2(src, dst)

            print(f"Processed {len(normal_images)} normal images and {len(monkeypox_images)} monkeypox images")
            return
        except Exception as e:
            print(f"Error organizing by extensions: {e}")
            raise ValueError("Could not process dataset. Please ensure the data is organized in class folders.")

    # Process each class
    for class_name in classes:
        print(f"Processing class: {class_name}")

        # Create the class directories
        train_class_dir = os.path.join(train_dir, class_name)
        val_class_dir = os.path.join(val_dir, class_name)
        os.makedirs(train_class_dir, exist_ok=True)
        os.makedirs(val_class_dir, exist_ok=True)

        # Collect all images of this class
        class_dir = os.path.join(input_dir, class_name)
        images = [img for img in os.listdir(class_dir) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]

        # Shuffle the images
        random.shuffle(images)

        # Split into training and validation sets
        split_idx = int(len(images) * split_ratio)
        train_images = images[:split_idx]
        val_images = images[split_idx:]

        # Copy images to the corresponding directories
        for img in tqdm(train_images, desc=f"Training {class_name}"):
            src = os.path.join(class_dir, img)
            dst = os.path.join(train_class_dir, img)
            try:
                # Verify the image is valid before copying
                with Image.open(src) as img_obj:
                    shutil.copy2(src, dst)
            except Exception as e:
                print(f"Error processing {src}: {e}")

        for img in tqdm(val_images, desc=f"Validation {class_name}"):
            src = os.path.join(class_dir, img)
            dst = os.path.join(val_class_dir, img)
            try:
                # Verify the image is valid before copying
                with Image.open(src) as img_obj:
                    shutil.copy2(src, dst)
            except Exception as e:
                print(f"Error processing {src}: {e}")

        print(f"Processed {class_name}: {len(train_images)} training images, {len(val_images)} validation images")


def verify_dataset(data_dir):
    """
    Verify that the dataset is valid.

    Args:
    - data_dir: data directory
    """
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val')

    if not os.path.exists(train_dir) or not os.path.exists(val_dir):
        return False

    train_classes = [d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d))]
    val_classes = [d for d in os.listdir(val_dir) if os.path.isdir(os.path.join(val_dir, d))]

    if not train_classes or not val_classes:
        return False

    # Make sure every class contains images
    for class_name in train_classes:
        class_dir = os.path.join(train_dir, class_name)
        images = [img for img in os.listdir(class_dir) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
        if not images:
            print(f"Warning: No images found in {class_dir}")
            return False

    for class_name in val_classes:
        class_dir = os.path.join(val_dir, class_name)
        images = [img for img in os.listdir(class_dir) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
        if not images:
            print(f"Warning: No images found in {class_dir}")
            return False

    return True


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Preprocess the weather dataset')
    parser.add_argument('--input_dir', type=str, required=True, help='Input data directory')
    parser.add_argument('--output_dir', type=str, default='./data/weather-1', help='Output data directory')
    parser.add_argument('--split_ratio', type=float, default=0.8, help='Train/validation split ratio')
    parser.add_argument('--verify', action='store_true', help='Only verify the dataset')

    args = parser.parse_args()

    if args.verify:
        if verify_dataset(args.output_dir):
            print("Dataset is valid!")
        else:
            print("Dataset is invalid or incomplete!")
    else:
        preprocess_dataset(args.input_dir, args.output_dir, args.split_ratio)
        print("Dataset preprocessing completed!")

        # Verify the dataset
        if verify_dataset(args.output_dir):
            print("Dataset is valid and ready for training!")
        else:
            print("Warning: Dataset may be incomplete or invalid. Please check the data.")
> python preprocess_dataset.py --input_dir './data/weather_classification'
Processing class: cloudy
Training cloudy: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [00:06<00:00, 1230.06it/s]
Validation cloudy: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1273.72it/s] 
Processed cloudy: 8000 training images, 2000 validation images
Processing class: haze
Training haze: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [00:06<00:00, 1219.19it/s] 
Validation haze: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1239.72it/s] 
Processed haze: 8000 training images, 2000 validation images
Processing class: rainy
Training rainy: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [00:06<00:00, 1177.47it/s] 
Validation rainy: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1181.43it/s] 
Processed rainy: 8000 training images, 2000 validation images
Processing class: snow
Training snow: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [00:06<00:00, 1160.55it/s] 
Validation snow: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1170.75it/s] 
Processed snow: 8000 training images, 2000 validation images
Processing class: sunny
Training sunny: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [00:06<00:00, 1260.57it/s] 
Validation sunny: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1253.31it/s]
Processed sunny: 8000 training images, 2000 validation images
Processing class: thunder
Training thunder: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [00:06<00:00, 1321.18it/s] 
Validation thunder: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1339.66it/s] 
Processed thunder: 8000 training images, 2000 validation images
Dataset preprocessing completed!
Dataset is valid and ready for training!
3. train_weather_model.py
import os  
import torch  
import torch.nn as nn  
import torch.optim as optim  
from torch.utils.data import DataLoader  
from torchvision import transforms, datasets  
import matplotlib.pyplot as plt  
import numpy as np  
from sklearn.metrics import confusion_matrix, classification_report  
from tqdm import tqdm  
import time  
import copy  
import argparse

# Import the InceptionV3 model implemented in step 1
from inceptionv3 import InceptionV3


def get_transforms():
    # Transforms for the training data
    train_transform = transforms.Compose([
        transforms.Resize((299, 299)),  # InceptionV3 expects 299x299 inputs
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Transforms for the validation/test data
    val_transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return train_transform, val_transform


def load_data(data_dir, batch_size=32):
    train_transform, val_transform = get_transforms()

    # Assume data_dir contains 'train' and 'val' subfolders
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val')

    # If there is no val directory, split the train directory instead
    if not os.path.exists(val_dir):
        print("Validation directory not found. Using a split of training data.")
        dataset = datasets.ImageFolder(train_dir, transform=train_transform)
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
        # Apply the validation transform (note: both subsets share the underlying
        # dataset object, so this also changes the transform seen by the training subset)
        val_dataset.dataset.transform = val_transform
    else:
        train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
        val_dataset = datasets.ImageFolder(val_dir, transform=val_transform)

    # Build the data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    # Determine the class names
    class_names = None
    if isinstance(train_dataset, datasets.ImageFolder):
        class_names = train_dataset.classes
    else:
        class_names = dataset.classes

    print(f"Class names: {class_names}")
    print(f"Training images: {len(train_dataset)}")
    print(f"Validation images: {len(val_dataset)}")

    return train_loader, val_loader, class_names


def train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs=25, scheduler=None):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # Record the training history
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print('-' * 10)

        # Each epoch has a training phase and a validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # training mode
                dataloader = train_loader
            else:
                model.eval()  # evaluation mode
                dataloader = val_loader

            running_loss = 0.0
            running_corrects = 0

            # Iterate over the data
            for inputs, labels in tqdm(dataloader, desc=phase):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()

                # Forward pass; track gradients only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward pass + optimization only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if scheduler is not None and phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / len(dataloader.dataset)
            epoch_acc = running_corrects.double() / len(dataloader.dataset)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Record loss and accuracy for this phase
            if phase == 'train':
                history['train_loss'].append(epoch_loss)
                history['train_acc'].append(epoch_acc.item())
            else:
                history['val_loss'].append(epoch_loss)
                history['val_acc'].append(epoch_acc.item())

            # Save the model if it is the best so far
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                # Save the best model
                torch.save(model.state_dict(), 'best_weather_model.pth')
                print(f'New best model saved with accuracy: {best_acc:.4f}')

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Load the best model weights
    model.load_state_dict(best_model_wts)
    return model, history


def plot_training(history):
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Training Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title('Loss over epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Training Accuracy')
    plt.plot(history['val_acc'], label='Validation Accuracy')
    plt.title('Accuracy over epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.show()


def main():
    parser = argparse.ArgumentParser(description='Train Weather Identification Model')
    parser.add_argument('--data_dir', type=str, default='./data/weather-1', help='Data directory')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--num_epochs', type=int, default=20, help='Number of epochs')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--save_path', type=str, default='weather_identification_model.pth', help='Model save path')

    args = parser.parse_args()

    # Load the data
    train_loader, val_loader, class_names = load_data(args.data_dir, args.batch_size)

    # Initialize the model
    model = InceptionV3(num_classes=len(class_names))
    model = model.to(device)

    # Define the loss function and the optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    # Learning-rate scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Train the model
    model, history = train_model(
        model, criterion, optimizer,
        train_loader, val_loader,
        num_epochs=args.num_epochs,
        scheduler=scheduler
    )

    # Plot the training history
    plot_training(history)

    # Save the final model
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'class_names': class_names
    }, args.save_path)
    print(f"Final model saved as '{args.save_path}'")


if __name__ == "__main__":
    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    main()
Epoch 19/20
----------
train: 100%|███████████████████████████████| 1500/1500 [06:41<00:00,  3.73it/s] 
train Loss: 0.5085 Acc: 0.8099
val: 100%|█████████████████████████████████| 375/375 [00:45<00:00,  8.27it/s] 
val Loss: 0.5092 Acc: 0.8134
New best model saved with accuracy: 0.8134

Epoch 20/20
----------
train: 100%|███████████████████████████████| 1500/1500 [06:43<00:00,  3.72it/s] 
train Loss: 0.5086 Acc: 0.8089
val: 100%|█████████████████████████████████| 375/375 [00:45<00:00,  8.28it/s] 
val Loss: 0.5122 Acc: 0.8088

[Figure: training and validation loss/accuracy curves (training_history.png)]

4. predict_weather.py
import torch  
import torch.nn.functional as F  
from torchvision import transforms  
from PIL import Image  
import argparse  
import os  
import numpy as np  
import matplotlib.pyplot as plt  
from inceptionv3 import InceptionV3


class WeatherPredictor:
    def __init__(self, model_path, device='auto'):
        """
        Initialize the weather predictor.

        Args:
            model_path: path to the trained model
            device: compute device ('cuda', 'cpu', 'auto')
        """
        # Device configuration
        if device == 'auto':
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"Using device: {self.device}")

        # Load the model
        self.model, self.class_names = self._load_model(model_path)

        # Image preprocessing
        self.transform = transforms.Compose([
            transforms.Resize((299, 299)),  # InceptionV3 input size
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def _load_model(self, model_path):
        """Load the trained model."""
        try:
            # Try to load a full checkpoint first
            checkpoint = torch.load(model_path, map_location=self.device)
            if 'model_state_dict' in checkpoint:
                # Full checkpoint with metadata
                model_state_dict = checkpoint['model_state_dict']
                class_names = checkpoint.get('class_names', None)
            else:
                # Plain state_dict with parameters only
                model_state_dict = checkpoint
                class_names = None
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            # Fall back to loading a parameters-only file
            model_state_dict = torch.load(model_path, map_location=self.device)
            class_names = None

        # Work out the number of classes
        if class_names:
            num_classes = len(class_names)
        else:
            # class_names was not saved; infer the class count from the weights
            last_layer_key = 'fc.weight'
            if last_layer_key in model_state_dict:
                num_classes = model_state_dict[last_layer_key].shape[0]
                # Try to infer class names from the data directory
                print("Class names not found in model file. Attempting to infer from data directory...")
                class_names = self._infer_class_names_from_data(num_classes)
                if not class_names:
                    # Fall back to generic class names
                    class_names = [f'class_{i}' for i in range(num_classes)]
                    print(f"Could not infer class names. Using default names: {class_names}")
                    print("To get correct class names, please provide the data directory or retrain with updated code.")
            else:
                # Default weather classes
                class_names = ['cloudy', 'rain', 'shine', 'sunrise']
                num_classes = len(class_names)

        print(f"Loading model with {num_classes} classes: {class_names}")

        # Build the model
        model = InceptionV3(num_classes=num_classes)
        model.load_state_dict(model_state_dict)
        model.to(self.device)
        model.eval()

        return model, class_names

    def _infer_class_names_from_data(self, num_classes):
        """Infer class names from the data directory."""
        # Commonly used data directory locations
        possible_data_dirs = [
            './data/weather-1/train',
            './data/weather-1',
            './data/train',
            './data',
            '../data/weather-1/train',
            '../data/weather-1',
        ]

        for data_dir in possible_data_dirs:
            if os.path.exists(data_dir):
                try:
                    # Use sub-directory names as class names
                    subdirs = [d for d in os.listdir(data_dir)
                               if os.path.isdir(os.path.join(data_dir, d))]
                    if len(subdirs) == num_classes:
                        subdirs.sort()  # keep a consistent order
                        print(f"Inferred class names from {data_dir}: {subdirs}")
                        return subdirs
                except Exception:
                    continue
        return None

    def preprocess_image(self, image_path):
        """Preprocess a single image."""
        try:
            # Open and convert the image
            image = Image.open(image_path).convert('RGB')
            # Apply the transforms
            input_tensor = self.transform(image)
            # Add the batch dimension
            input_batch = input_tensor.unsqueeze(0)
            return input_batch.to(self.device), image
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return None, None

    def predict_single(self, image_path, show_confidence=True):
        """Predict a single image."""
        input_batch, original_image = self.preprocess_image(image_path)
        if input_batch is None:
            return None

        with torch.no_grad():
            outputs = self.model(input_batch)
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        predicted_class = self.class_names[predicted.item()]
        confidence_score = confidence.item()

        # Probabilities for every class
        all_probs = probabilities[0].cpu().numpy()

        result = {
            'predicted_class': predicted_class,
            'confidence': confidence_score,
            'all_probabilities': dict(zip(self.class_names, all_probs)),
            'image_path': image_path
        }

        if show_confidence:
            print(f"Image: {os.path.basename(image_path)}")
            print(f"Predicted: {predicted_class} (Confidence: {confidence_score:.4f})")
            print("All probabilities:")
            for class_name, prob in result['all_probabilities'].items():
                print(f"  {class_name}: {prob:.4f}")
            print("-" * 50)

        return result

    def predict_batch(self, image_paths):
        """Predict a batch of images."""
        results = []
        print(f"Predicting {len(image_paths)} images...")
        for i, image_path in enumerate(image_paths):
            print(f"Processing {i + 1}/{len(image_paths)}: {os.path.basename(image_path)}")
            result = self.predict_single(image_path, show_confidence=False)
            if result:
                results.append(result)
        return results

    def predict_directory(self, directory_path, extensions=None):
        """Predict every image in a directory."""
        if extensions is None:
            extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']

        image_paths = []
        for filename in os.listdir(directory_path):
            if any(filename.lower().endswith(ext) for ext in extensions):
                image_paths.append(os.path.join(directory_path, filename))

        if not image_paths:
            print(f"No images found in {directory_path}")
            return []

        return self.predict_batch(image_paths)

    def visualize_prediction(self, image_path, save_path=None):
        """Visualize a prediction result."""
        result = self.predict_single(image_path, show_confidence=False)
        if not result:
            return

        # Load the original image
        image = Image.open(image_path)

        # Create the figure
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Show the original image
        ax1.imshow(image)
        ax1.set_title(f'Original Image\n{os.path.basename(image_path)}')
        ax1.axis('off')

        # Show the predicted probabilities
        classes = list(result['all_probabilities'].keys())
        probs = list(result['all_probabilities'].values())
        colors = ['green' if c == result['predicted_class'] else 'skyblue' for c in classes]

        bars = ax2.bar(classes, probs, color=colors)
        ax2.set_title(f'Prediction: {result["predicted_class"]}\nConfidence: {result["confidence"]:.4f}')
        ax2.set_ylabel('Probability')
        ax2.set_ylim(0, 1)

        # Add value labels above the bars
        for bar, prob in zip(bars, probs):
            height = bar.get_height()
            ax2.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                     f'{prob:.3f}', ha='center', va='bottom')

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Visualization saved to {save_path}")
        plt.show()


def main():
    parser = argparse.ArgumentParser(description='Weather Classification Prediction')
    parser.add_argument('--model_path', type=str, required=True,
                        help='Path to the trained model file')
    parser.add_argument('--image_path', type=str,
                        help='Path to a single image to predict')
    parser.add_argument('--image_dir', type=str,
                        help='Path to directory containing images to predict')
    parser.add_argument('--output_file', type=str,
                        help='Path to save prediction results (CSV format)')
    parser.add_argument('--visualize', action='store_true',
                        help='Show visualization of predictions')
    parser.add_argument('--device', type=str, default='auto',
                        choices=['auto', 'cuda', 'cpu'],
                        help='Device to use for inference')
    args = parser.parse_args()

    # Initialize the predictor
    predictor = WeatherPredictor(args.model_path, args.device)

    results = []

    # Single-image prediction
    if args.image_path:
        if not os.path.exists(args.image_path):
            print(f"Image file not found: {args.image_path}")
            return
        print("Predicting single image...")
        result = predictor.predict_single(args.image_path)
        if result:
            results.append(result)
        if args.visualize:
            predictor.visualize_prediction(args.image_path)

    # Batch prediction over a directory
    elif args.image_dir:
        if not os.path.exists(args.image_dir):
            print(f"Directory not found: {args.image_dir}")
            return
        print("Predicting images in directory...")
        results = predictor.predict_directory(args.image_dir)
    else:
        print("Please specify either --image_path or --image_dir")
        return

    # Save the results
    if args.output_file and results:
        import pandas as pd
        # Assemble the rows
        data = []
        for result in results:
            row = {
                'image_path': result['image_path'],
                'predicted_class': result['predicted_class'],
                'confidence': result['confidence']
            }
            # Include the probability of every class
            for class_name, prob in result['all_probabilities'].items():
                row[f'prob_{class_name}'] = prob
            data.append(row)

        # Write to CSV
        df = pd.DataFrame(data)
        df.to_csv(args.output_file, index=False)
        print(f"Results saved to {args.output_file}")

    # Print a summary
    if results:
        print(f"\nSUMMARY: Processed {len(results)} images")
        class_counts = {}
        for result in results:
            pred_class = result['predicted_class']
            class_counts[pred_class] = class_counts.get(pred_class, 0) + 1
        print("Prediction distribution:")
        for class_name, count in sorted(class_counts.items()):
            print(f"  {class_name}: {count}")


if __name__ == "__main__":
    main()

Prediction results:

> python predict_weather.py --model_path best_weather_model.pth --image_path ./data/weather_photos/rain/rain1.jpg --visualize
Using device: cuda
Class names not found in model file. Attempting to infer from data directory...
Inferred class names from ./data/weather-1/train: ['cloudy', 'haze', 'rainy', 'snow', 'sunny', 'thunder']
Loading model with 6 classes: ['cloudy', 'haze', 'rainy', 'snow', 'sunny', 'thunder']
Predicting single image...
Image: rain1.jpg
Predicted: rainy (Confidence: 0.7616)
All probabilities:
  cloudy: 0.1271
  haze: 0.0568
  rainy: 0.7616
  snow: 0.0465
  sunny: 0.0062
  thunder: 0.0017
--------------------------------------------------

[Figure: prediction visualization for rain1.jpg, original image beside per-class probability bars]

> python predict_weather.py --model_path best_weather_model.pth --image_path ./data/0001.jpg --visualize       
Using device: cuda
Class names not found in model file. Attempting to infer from data directory...
Inferred class names from ./data/weather-1/train: ['cloudy', 'haze', 'rainy', 'snow', 'sunny', 'thunder']
Loading model with 6 classes: ['cloudy', 'haze', 'rainy', 'snow', 'sunny', 'thunder']
Predicting single image...
Image: 0001.jpg
Predicted: sunny (Confidence: 0.9979)
All probabilities:
  cloudy: 0.0003
  haze: 0.0013
  rainy: 0.0000
  snow: 0.0005
  sunny: 0.9979
  thunder: 0.0000

[Figure: prediction visualization for 0001.jpg, original image beside per-class probability bars]

(三)Summary

The advantages of small convolution kernels

Splitting a large convolution kernel into several stacked small ones significantly reduces the parameter count, for the reasons below.
(Reference: Everyservice, a personal blog)
Comparing parameter counts

Take a 7×7 kernel as an example:

  • A single 7×7 kernel: 7 × 7 = 49 parameters
  • Three stacked 3×3 kernels instead: 3 × (3 × 3) = 27 parameters

That is a reduction of roughly 45% (49 → 27).

Why it works

  1. Same receptive field: stacked small kernels cover the same receptive field as one large kernel. For example, two 3×3 kernels have the receptive field of one 5×5 kernel, and three 3×3 kernels have that of one 7×7 kernel.
  2. More efficient parameter sharing: small kernels reuse their parameters across repeated applications, whereas a large kernel needs more independent parameters to cover the same area.
  3. Extra non-linearity: activation functions can be inserted between the small kernels, increasing the network's non-linear expressive power compared with a single large kernel.

A concrete example

With C input and output channels:

  • One 5×5 convolution layer: C × 5 × 5 × C = 25C² parameters
  • Two 3×3 convolution layers: 2 × (C × 3 × 3 × C) = 18C² parameters

That is 28% fewer parameters for the same receptive field; both comparisons are verified in the sketch below.
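To make the arithmetic concrete, here is a minimal PyTorch sketch that verifies both comparisons above. It uses bias=False so that only kernel weights are counted, and C = 64 is an arbitrary example; the ReLU between the stacked kernels illustrates the extra non-linearity and contributes no parameters.

import torch.nn as nn

def n_params(m):
    return sum(p.numel() for p in m.parameters())

# Single-channel case: one 7x7 kernel vs. three stacked 3x3 kernels
conv7 = nn.Conv2d(1, 1, kernel_size=7, padding=3, bias=False)
conv3_stack = nn.Sequential(*[nn.Conv2d(1, 1, 3, padding=1, bias=False) for _ in range(3)])
print(n_params(conv7), n_params(conv3_stack))   # 49 27  (~45% fewer)

# C-channel case: one 5x5 layer (25*C^2) vs. two 3x3 layers (18*C^2)
C = 64
conv5 = nn.Conv2d(C, C, kernel_size=5, padding=2, bias=False)
conv3x2 = nn.Sequential(
    nn.Conv2d(C, C, 3, padding=1, bias=False),
    nn.ReLU(inplace=True),  # activation between the small kernels; adds no parameters
    nn.Conv2d(C, C, 3, padding=1, bias=False),
)
print(n_params(conv5), n_params(conv3x2))       # 102400 73728  (28% fewer)
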
About the auxiliary classifier

The auxiliary classifier in Inception V3 is a technique for helping deep networks train, aimed mainly at the vanishing-gradient problem.

Basic idea

An auxiliary classifier is an extra classifier attached to an intermediate layer of the network. It is not the final output but a training-time aid. GoogLeNet (Inception V1) used two auxiliary classifiers; Inception V3 keeps a single one, attached partway through the network.

Typical structure

A typical auxiliary classifier consists of:

  1. A 5×5 average-pooling layer (stride 3)
  2. A 1×1 convolution layer (for dimensionality reduction)
  3. Two fully connected layers
  4. A dropout layer
  5. A final softmax classification layer
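Below is a minimal sketch of this structure. The 128-channel and 1024-unit sizes follow the original GoogLeNet paper, and the flatten size assumes a 14×14 input feature map (so the 5×5, stride-3 pooling yields 4×4); treat all of these numbers as illustrative assumptions, not the exact head used inside Inception V3.

import torch
import torch.nn as nn
import torch.nn.functional as F

class AuxClassifier(nn.Module):
    """Sketch of the auxiliary head described above (illustrative sizes)."""
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 128, kernel_size=1)  # 2. 1x1 conv for dimensionality reduction
        self.fc1 = nn.Linear(128 * 4 * 4, 1024)                 # 3. first fully connected layer
        self.dropout = nn.Dropout(0.7)                          # 4. dropout
        self.fc2 = nn.Linear(1024, num_classes)                 # 5. final classifier (softmax is folded into the loss)

    def forward(self, x):
        x = F.avg_pool2d(x, kernel_size=5, stride=3)  # 1. 5x5 average pooling, stride 3 (14x14 -> 4x4)
        x = F.relu(self.conv(x))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

During training the head's logits feed the auxiliary loss; at inference time the module is simply never called.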

Main roles

  1. Mitigating vanishing gradients: it supplies an extra gradient signal to the middle of the network, helping the earlier layers of a deep network train properly.
  2. Regularization: its loss is added to the total loss, acting as a regularizer that discourages overfitting.
  3. Guiding feature learning: it forces intermediate layers to learn representations that are useful for classification.

Loss computation

total loss = main classifier loss + 0.3 × auxiliary classifier loss

The weight 0.3 is an empirical value that keeps the auxiliary loss from dominating the main learning objective.
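As a sketch of how the two losses combine in one training step: this assumes a model that returns (main_logits, aux_logits) in training mode, the way torchvision's Inception3 does. The InceptionV3 class used in this post returns a single output, so this is an illustrative assumption, not the training loop used above.

import torch.nn as nn

criterion = nn.CrossEntropyLoss()

def training_step(model, inputs, labels, optimizer, aux_weight=0.3):
    # Assumes model(inputs) returns (main_logits, aux_logits) in training mode
    model.train()
    optimizer.zero_grad()
    main_logits, aux_logits = model(inputs)
    # total loss = main loss + 0.3 * auxiliary loss
    loss = criterion(main_logits, labels) + aux_weight * criterion(aux_logits, labels)
    loss.backward()
    optimizer.step()
    return loss.item()
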

How it is used

  • Training: the auxiliary classifier takes part in loss computation and back-propagation
  • Inference: the auxiliary classifier is discarded and only the main classifier's output is used

This design lets Inception V3 train deeper networks while retaining good convergence and generalization.
