当前位置: 首页 > ds >正文

VGG16 训练和测试 Fashion-MNIST 和 CIFAR10

CIFAR10 epoch=50

Test Accuracy: 91.77%

VGG网络结构

VGG16(配置 D)网络结构:5 个卷积 block(共 13 个卷积层,每个 block 末尾各接 1 个 MaxPool,共 5 个)+ 3 个全连接层 + softmax。注意:下面的代码最后一层只输出 logits,没有显式的 softmax 层。

VGG网络参数(标准 ImageNet 版 VGG16 结构;此处输入改为单通道、输出改为 10 类)

第一个 VGG Block 层(两个卷积层、两个 ReLU、一个池化层)

# Block 1: two 3x3 convolutions (1 -> 64 -> 64 channels; padding 1 with
# stride 1 keeps H and W unchanged), each followed by ReLU, then a 2x2
# max-pool that halves H and W.
self.block1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2),  # stride defaults to kernel_size (2)
)

第二个 VGG Block 层(两个卷积层、两个 ReLU、一个池化层)

# Block 2: two 3x3 convolutions (64 -> 128 -> 128 channels, size-preserving),
# each followed by ReLU, then a 2x2 max-pool.
self.block2 = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(128, 128, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2),  # stride defaults to kernel_size (2)
)

第三个 VGG Block 层(三个卷积层、三个 ReLU、一个池化层)

        # Block 3: three 3x3 convolutions (128 -> 256 -> 256 -> 256 channels,
        # size-preserving), each followed by ReLU, then a 2x2 max-pool.
        self.block3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),  # stride defaults to kernel_size (2)
        )

第四个 VGG Block 层(三个卷积层、三个 ReLU、一个池化层)

        # Block 4: three 3x3 convolutions (256 -> 512 -> 512 -> 512 channels,
        # size-preserving), each followed by ReLU, then a 2x2 max-pool.
        self.block4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),  # stride defaults to kernel_size (2)
        )

第五个 VGG Block 层(三个卷积层、三个 ReLU、一个池化层)

        # Block 5: three 3x3 convolutions that keep 512 channels
        # (size-preserving), each followed by ReLU, then a 2x2 max-pool.
        self.block5 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),  # stride defaults to kernel_size (2)
        )

线性分类器层(Flatten、Linear、ReLU、Dropout)

        # Classifier head: flatten the conv features (the first Linear expects
        # 7*7*512 = 25088 values), two 4096-wide hidden layers with ReLU and
        # Dropout, then 10 raw output logits (no softmax).
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 10),
        )

前向传播代码

    def forward(self, x):
        """Pass ``x`` through block1..block5 in order, then the classifier.

        Returns whatever ``self.classifier`` produces for the final features.
        """
        for stage in (self.block1, self.block2, self.block3, self.block4, self.block5):
            x = stage(x)
        return self.classifier(x)

Kaiming 权重初始化函数

# ===== Kaiming 初始化函数 =====
def init_weights(m):
    """Kaiming-initialise one module; intended for ``model.apply(init_weights)``.

    ``nn.Conv2d`` and ``nn.Linear`` weights are drawn with
    ``kaiming_normal_(mode='fan_out', nonlinearity='relu')``, and their biases
    are zeroed when present. Every other module type is left untouched.

    Args:
        m: the submodule currently being visited by ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        # Layers built with bias=False have m.bias = None. Previously only the
        # Conv2d branch checked this, so a bias-free Linear raised here.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

主函数代码

if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = VGG16().to(device)
    # Kaiming-initialise the Conv2d / Linear layers, as the complete script does.
    model.apply(init_weights)
    # summary() prints the layer table itself; wrapping it in print() would
    # only add an extra line showing its return value.
    summary(model, (1, 224, 224))

完整代码

import torch
import torch.nn as nn
from torchsummary import summary
class VGG16(nn.Module):
    """VGG16 (configuration D): 13 conv layers in five blocks plus three FC layers.

    Each block is ``num_convs`` x (3x3 conv, stride 1, padding 1 -> ReLU)
    followed by a 2x2 max-pool, so every block halves H and W. The classifier
    was designed for the 7x7x512 feature map a 224x224 input produces.

    An adaptive average pool in front of the classifier resizes other map
    sizes to 7x7, so e.g. 32x32 CIFAR10 images run as well. For those, the
    1x1 map is simply replicated; a dedicated CIFAR variant with a smaller
    head is cheaper. The network returns raw logits (no softmax).

    Args:
        in_channels: channels of the input images (1 for grayscale, 3 for
            RGB). Defaults to 1, as before.
        num_classes: number of output logits. Defaults to 10, as before.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        # Attribute names block1..block5 and the layer order inside each
        # Sequential are unchanged, so state_dict keys stay compatible.
        self.block1 = self._vgg_block(in_channels, 64, 2)
        self.block2 = self._vgg_block(64, 128, 2)
        self.block3 = self._vgg_block(128, 256, 3)
        self.block4 = self._vgg_block(256, 512, 3)
        self.block5 = self._vgg_block(512, 512, 3)
        # Parameter-free and an identity for a 7x7 map (224x224 input). It
        # lets other input sizes reach the 7*7*512 Linear below instead of
        # failing with a shape mismatch.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(7 * 7 * 512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    @staticmethod
    def _vgg_block(in_ch, out_ch, num_convs):
        """Build one VGG block: num_convs x (3x3 conv -> ReLU), then a 2x2 max-pool."""
        layers = []
        ch = in_ch
        for _ in range(num_convs):
            layers.append(nn.Conv2d(ch, out_ch, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(inplace=True))
            ch = out_ch
        layers.append(nn.MaxPool2d(2, 2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the five conv blocks, resize to 7x7, then classify; returns logits."""
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5):
            x = block(x)
        x = self.avgpool(x)
        return self.classifier(x)


# ===== Kaiming initialization =====
def init_weights(m):
    """Kaiming-initialise one module; intended for ``model.apply(init_weights)``.

    ``nn.Conv2d`` and ``nn.Linear`` weights are drawn with
    ``kaiming_normal_(mode='fan_out', nonlinearity='relu')``, and their biases
    are zeroed when present. Every other module type is left untouched.

    NOTE(review): the ReLU gain is also applied to the final output Linear,
    which is not followed by a ReLU. Confirm this is intended.

    Args:
        m: the submodule currently being visited by ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        # Layers built with bias=False have m.bias = None. Previously only the
        # Conv2d branch checked this, so a bias-free Linear raised here.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = VGG16().to(device)
    # Apply Kaiming initialisation to every Conv2d / Linear submodule.
    model.apply(init_weights)
    # Print the per-layer model summary (summary() prints it itself).
    summary(model, (1, 224, 224))
运行结果:Total params: 134,300,362

VGG网络参数(适用于CIFAR10的 VGG16)

前面的网络参数量太大,一般适用于 ImageNet(224×224 输入)数据集。而 CIFAR10 图像只有 32×32×3,经过 5 次 2×2 池化后特征图仅剩 1×1×512,与分类器中按 7×7×512 设计的第一个全连接层不匹配。因此需要针对 CIFAR10 重新设计模型参数,并可以加入 BatchNorm 来提高收敛速度和训练稳定性。

http://www.xdnf.cn/news/17343.html

相关文章:

  • Verilog 仿真问题:打拍失败
  • jdk动态代理如何实现
  • 对 .NET线程 异常退出引发程序崩溃的反思
  • 八股——IM项目
  • C++ 运算符重载:避免隐式类型转换的艺术
  • 译 | 在 Python 中从头开始构建 Qwen-3 MoE
  • 【ArcGIS】分区统计中出现Null值且Nodata无法忽略的问题以及shp擦除(erase)的使用——以NDVI去水体为例
  • 最新教程 | CentOS 7 下 MySQL 8 离线部署完整手册(含自动部署脚本)
  • vite项目中集成vditor文档编辑器
  • 低代码系统的技术深度:超越“可视化操作”的架构与实现挑战
  • 【机器学习篇】02day.python机器学习篇Scikit-learn基础操作
  • 疯狂星期四文案网第30天运营日记
  • 自学嵌入式 day43 中断系统
  • 数据结构与算法的认识
  • Linux 防火墙(firewalld)详解与配置
  • 【概念学习】早期神经网络
  • IPS知识点
  • spring-dubbo
  • ##Anolis OS 8.10 安装oracle19c
  • 从零开始的CAD|CAE开发: 单柱绕流+多柱绕流
  • vue封装一个cascade级联 多选 全选组件 ,原生写法Input,Checkbox,Button
  • 看不见的伪造痕迹:AI时代的鉴伪攻防战
  • Codeforces Round 987 (Div. 2)
  • 数据结构—队列和栈
  • 问题定位排查手记1 | 从Windows端快速检查连接状态
  • Java面试宝典:类加载器分层设计与核心机制解析
  • PyCharm vs. VSCode 到底哪个更好用
  • C++、STL面试题总结(二)
  • 图论(邻接表)DFS
  • SpringBoot 接入SSE实现消息实时推送的优点,原理以及实现