当前位置: 首页 > ai >正文

PyTorch——优化器(9)

优化器根据梯度调整参数,以达到降低误差

import torch.optim
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader# 加载CIFAR10测试数据集,设置transform将图像转换为Tensor
dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)
# 创建数据加载器,设置批量大小为64
dataloader = DataLoader(dataset, batch_size=64)# 定义卷积神经网络模型
class TY(nn.Module):def __init__(self):super(TY, self).__init__()# 构建网络结构:3个卷积层+池化层组合,2个全连接层self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),    # 输入3通道,输出32通道,卷积核5x5MaxPool2d(2),                   # 最大池化,步长2Conv2d(32, 32, 5, padding=2),   # 第二层卷积MaxPool2d(2),                   # 第二次池化Conv2d(32, 64, 5, padding=2),   # 第三层卷积MaxPool2d(2),                   # 第三次池化Flatten(),                      # 将多维张量展平为向量Linear(1024, 64),               # 全连接层,输入1024维,输出64维Linear(64, 10),                 # 输出层,10个类别对应10个输出)def forward(self, x):# 定义前向传播路径x = self.model1(x)return x# 定义损失函数(交叉熵损失适用于多分类问题)
loss = nn.CrossEntropyLoss()
# 实例化模型
ty = TY()
# 定义优化器(随机梯度下降),设置学习率为0.01
optim = torch.optim.SGD(ty.parameters(), lr=0.01)# 训练20个完整轮次
for epoch in range(20):running_loss = 0.0  # 初始化本轮累计损失# 遍历数据加载器中的每个批次for data in dataloader:imgs, targets = data  # 获取图像和标签outputs = ty(imgs)    # 前向传播result_loss = loss(outputs, targets)  # 计算损失optim.zero_grad()     # 梯度清零,防止累积result_loss.backward()  # 反向传播计算梯度optim.step()          # 更新模型参数running_loss += result_loss  # 累加损失值# 打印本轮训练的累计损失print(f"Epoch {epoch+1}, Loss: {running_loss}")

http://www.xdnf.cn/news/11942.html

相关文章:

  • 近几年字节飞书测开部分面试题整理
  • 【计网】SW、GBN、SR、TCP
  • 深度学习之模型压缩三驾马车:模型剪枝、模型量化、知识蒸馏
  • Linux——TCP和UDP
  • 6月14日开班,ESG 合规分析师招生通知
  • FreeRTOS,MicroPython,区别与联系
  • 新制作文件系统占满:Error writing to file - write (28: No space left on device)
  • 雷卯针对易百纳 海思Hi3519AV100开发板防雷防静电方案
  • 虚拟机无法开启-关掉虚拟化
  • ROS中的里程计与IMU的消息类型解读
  • 深入解析异步爬虫中的协程原理:从概念到工程实践
  • c++对imu的角速度积分得到表示旋转四元数
  • 半导体热电技术方案领跑者「富信」×企企通,构建采购数字化升级
  • 【操作系统】基础回顾(一)
  • 解决IDE编译JAVA项目时出现的OOM异常问题
  • LeetCode[513]找树左下角的值
  • C语言基础(11)【函数1】
  • FreeRTOS、Zephyr、RT-Thread,区别与联系
  • 第八部分:第一节 - 初识 React:构建交互式点餐界面骨架
  • 《射频识别(RFID)原理与应用》期末复习 RFID第一章 射频识别技术概论(知识点总结+习题巩固)
  • 2025年计算机科学与网络安全国际会议(CSNS 2025)
  • VSCode主题设计大赛解析与实践指南
  • win10打包的exe在win7运行不了
  • 【Linux】线程同步
  • 《AI角色扮演反诈技术解析:原理、架构与核心挑战》
  • UDP与TCP的区别是什么?
  • 第八部分:第三节 - 事件处理:响应顾客的操作
  • Nginx 文件目录结构总览
  • 10. MySQL索引
  • 泛型编程技巧——使用std::enable_if实现按类型进行条件编译​