当前位置: 首页 > java >正文

动手学深度学习:手语视频在NiN模型中的测试

前言

NiN模型是在LeNet的基础上修改,提出了1x1卷积层和全局平均池化层的概念,减少了全连接所带来的参数量很多的问题。本篇在之前代码的基础上添加了模型保存,loss和acc记录以及记录模型时间等功能,所以模型后面的代码会重新记录一下。

模型

NiN模型主要的特色有1x1卷积和全局平均池化,以下是我个人的一些看法。

1x1卷积

由于再模型结尾将不再使用全连接层,如果还是原有的3x3等卷积的话就会丢失通道之间的信息,而1x1卷积在不改变图片大小的前提下,对通道进行卷积,可以解决这一问题。

全局平均池化

这个层主要是对每一个通道的图像进行池化变成1x1大小,也是取代全连接缩小像素得到要输出的类别大小的功能,如果说全连接是横向排布,不断减少到需要的数量的话(第一张图),那么全局平均池化就是竖向连接,一次性缩小到需要的形状(第二张图)。
在这里插入图片描述

在这里插入图片描述

代码

import torch.nn as nn
import os
import time
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# 定义设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
def nin_block(in_channels,out_channels,kernel_size,strides,padding):return nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding),nn.ReLU(),nn.Conv2d(out_channels,out_channels,1),nn.ReLU(),nn.Conv2d(out_channels,out_channels,1),nn.ReLU())
net=nn.Sequential(
nin_block(frames_len,96,11,4,0),
nn.MaxPool2d(3,2),
nin_block(96,256,5,1,2),
nn.MaxPool2d(3,2),
nin_block(256,384,3,1,1),
nn.MaxPool2d(3,2),
nn.Dropout(0.5),
nin_block(384,len(labels),3,1,1),
nn.AdaptiveAvgPool2d((1,1)),
nn.Flatten()).to(device)
def init_weight(m):if type(m)==nn.Linear or type(m)==nn.Conv2d:nn.init.xavier_uniform_(m.weight)
net.apply(init_weight)

在这里插入图片描述
学习率单独设置了一个变量

loss_fn=nn.CrossEntropyLoss()
lr=0.001
optimer=torch.optim.SGD(net.parameters(),lr=lr)#0.01会导致loss为nan

定义保存路径和模型名次,将主要需要调节的参数作为整个保存的文件夹名更易于区分。

# 初始化最小测试损失
best_test_loss = float('inf')
model_name="NiN"
epochs_num=300
save_path="./save/"+model_name+"_input_channels"+str(frames_len)+"_output_channels"+str(len(labels))+"_lr"+str(lr)+"_epochs"+str(epochs_num)+"/"
if not os.path.exists(save_path):os.makedirs(save_path)print(f"文件夹 '{save_path}' 已创建。")
else:print(f"文件夹 '{save_path}' 已存在。")
best_model_path = save_path+'model.pt'
best_onnx_path=save_path+"model.onnx"

添加计时功能,便于查看模型训练时间

train_len=len(train_iter.dataset)
all_acc=[]
all_loss=[]
test_all_acc=[]
test_all_loss=[]
start_time = time.time()
shape=None
for epoch in range(epochs_num):acc=0loss=0for x,y in train_iter:x=x.to(device)y=y.to(device)hat_y=net(x)l=loss_fn(hat_y,y)loss+=loptimer.zero_grad()l.backward()optimer.step()acc+=(hat_y.argmax(1)==y).sum()all_acc.append((acc/train_len).cpu().numpy())all_loss.append(loss.detach().cpu().numpy())
#     print(all_loss)test_acc=0test_loss=0test_len=len(test_iter.dataset)with torch.no_grad():for x,y in test_iter:x=x.to(device)y=y.to(device)shape=x.shapehat_y=net(x)test_loss+=loss_fn(hat_y,y)test_acc+=(hat_y.argmax(1)==y).sum()test_all_acc.append((test_acc/test_len).cpu().numpy())test_all_loss.append(test_loss.detach().cpu().numpy())print(f'{epoch}的test的acc{test_acc/test_len}')# 保存测试损失最小的模型if test_loss < best_test_loss:best_test_loss = test_losstorch.save(net, best_model_path)
#         dummy_input = torch.randn(shape).to(device)
#         torch.onnx.export(net, dummy_input, best_onnx_path, opset_version=11)print(f'Saved better model with Test Loss: {best_test_loss:.4f}')
end_time = time.time()
elapsed_time = end_time - start_time  # 计算耗时
print(f"程序运行了 {elapsed_time:.4f} 秒")  # 保留4位小数

在这里插入图片描述
针对loss添加了test的记录并且将图片保存起来便于以后查看

import matplotlib.pyplot as plt
plt.plot(range(1,epochs_num+1),all_loss,'.-',label='train_loss')
plt.text(epochs_num, all_loss[-1], f'{all_loss[-1]:.4f}', fontsize=12, verticalalignment='bottom')
plt.plot(range(1,epochs_num+1),test_all_loss,'.-',label='test_loss')
plt.text(epochs_num, test_all_loss[-1], f'{test_all_loss[-1]:.4f}', fontsize=12, verticalalignment='bottom')
plt.legend()
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig(save_path+"train_loss.png")

在这里插入图片描述

acc同理处理

plt.plot(range(1,epochs_num+1),all_acc,'-',label='train_acc')
plt.text(epochs_num, all_acc[-1], f'{all_acc[-1]:.4f}', fontsize=12, verticalalignment='bottom')
plt.plot(range(1,epochs_num+1),test_all_acc,'-.',label='test_acc')
plt.text(epochs_num, test_all_acc[-1], f'{test_all_acc[-1]:.4f}', fontsize=12, verticalalignment='bottom')
plt.legend()
plt.xlabel("epoch")
plt.ylabel("acc")
plt.savefig(save_path+"acc.png")

在这里插入图片描述

结论

NiN整体效果上比VGG还是要差一点的,收敛速度也很慢。但是运行时间比VGG快了快一倍,VGG花费了下图时间。
在这里插入图片描述

http://www.xdnf.cn/news/247.html

相关文章:

  • C++——C++11常用语法总结
  • 嵌入式面试常见算法题解析:数组元素移动与二分查找
  • 在 Vue 3 项目中引入 js-cookie 库
  • 打造一个 AI 面试助手:输入岗位 + 技术栈 → 自动生成面试问题 + 标准答案 + 技术考点图谱
  • 2025年03月中国电子学会青少年软件编程(Python)等级考试试卷(五级)真题
  • vue3学习笔记之属性绑定
  • 适合制作电磁铁的材料及特性
  • STL简介 + string【上】
  • 图像篡改检测算法
  • 【MATLAB代码例程】AOA与TOA结合的高精度平面地位,适用于四个基站的情况,附完整的代码
  • 万字解析TCP
  • 一次制作参考网杂志的阅读书源的实操经验总结(附书源)
  • 【无人机】电子速度控制器 (ESC) 驱动电机,常见的电调协议,PWM协议,Oneshot协议,DShot协议
  • Linux 网络接口 /sys/class/net/eth0 文件详解
  • 力扣面试150题--两数之和 和 快乐数
  • Java 2025:解锁未来5大技术趋势,Kotlin融合AI新篇
  • Server - 优雅的配置服务器 Bash 环境(.bashrc)
  • 无人机在农业中的应用与挑战!
  • 华为Pura X如何编辑图片、调整色调?图片编辑技巧、软件分享
  • git 出现 port 443 Connection timed out
  • 复现SCI图像增强(Toward fast, flexible, and robust low-light image enhancement.)
  • 【mysql】mysql疑难问题:实际场景解释什么是排它锁 当前读 快照读
  • YOLOv11改进:基于小波卷积WTConv的大感受野目标检测网络-
  • 使用 vcpkg 构建支持 HTTPS 的 libcurl 并解决常见链接错误
  • Java反射机制深度解析与应用案例
  • 第18周:对于ResNeXt-50算法的思考
  • Crawl4AI:重塑大语言模型数据供给的开源革命者
  • 前端资源加载失败后重试加载(CSS,JS等引用资源)
  • 在msys2里面编译antlr4的过程记录
  • C言雅韵集:野指针