当前位置: 首页 > java >正文

实战:用 PyTorch 复现一个 3 层全连接网络,训练 MNIST,达到 95%+ 准确率

1. 使用 Anaconda 创建一个新环境,包括 python 和 与你显卡对应的 torch

2. PyCharm(2025.1.3.1)绑定 Conda 环境-CSDN博客

3. 

import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm# 一次给模型看多少张图片
BATCH_SIZE = 64
# 把全部训练数据重复看多少遍
EPOCHS = 10
LR = 1e-3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])
train_set = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_set  = datasets.MNIST(root="./data", train=False, download=True, transform=transform)# 原始数据集中,一张 MNIST 图片的形状是 (1, 28, 28) ← 1 个通道(灰度),高 28,宽 28。
# 当 DataLoader 按 batch_size=64 打包后,它把 64 张这样的图片堆在一起,形成一个新的 4 维张量,形状变成 (64, 1, 28, 28)
# shuffle = True 的作用:在每个 epoch 开始时,把训练集里的 60 000 张图片顺序彻底打乱一次。
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader  = DataLoader(test_set,  batch_size=BATCH_SIZE)# 搭建神经网络:把图片拉成一条长条 → 过 128 个神经元 → 再过 64 个神经元 → 最后给出 10 个数字的得分
class Net(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Flatten(),nn.Linear(784, 128), nn.ReLU(),nn.Linear(128, 64),  nn.ReLU(),nn.Linear(64, 10))def forward(self, x):return self.net(x)model = Net().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)# 训练
for epoch in range(1, EPOCHS + 1):model.train()pbar = tqdm(train_loader, desc=f"Epoch {epoch}")for x, y in pbar:x, y = x.to(DEVICE), y.to(DEVICE)optimizer.zero_grad()loss = criterion(model(x), y)loss.backward()optimizer.step()pbar.set_postfix(loss=loss.item())model.eval()
correct = total = 0
with torch.no_grad():for x, y in test_loader:x, y = x.to(DEVICE), y.to(DEVICE)pred = model(x).argmax(1)correct += (pred == y).sum().item()total += y.size(0)
print(f"Test Accuracy: {100*correct/total:.2f}%")

4. 运行

http://www.xdnf.cn/news/17633.html

相关文章:

  • IoT/透过oc_lwm2m/boudica150 源码中的AT指令序列,分析NB-IoT接入华为云物联网平台IoTDA的工作机制
  • Java使用“Microsoft Print To PDF”打印时如何指定输出路径
  • Vue 利用el-table和el-pagination组件,简简单单实现表格前端分页
  • AI时代基于云原生的 CI/CD 基础设施 Tekton
  • Dubbo从入门到实战:分布式服务开发指南
  • USB 基本描述符
  • 视频播放器哪个好用?视频播放器PotPlayer,KMP Player
  • 下一个排列 的 思路总结
  • 从零开始的云计算生活——项目实战容器化
  • 标准IO详解(fgets、gets、fread、fwrite、fseek 等应用)
  • Java 包装类简单认识泛型
  • 《深度解构:React与Redux构建复杂表单的底层逻辑与实践》
  • C#使用EPPlus读写Excel
  • ubuntu20.04交叉编译vlc3.0.21 x64 windows版本
  • 大模型落地:AI 技术重构工作与行业的底层逻辑
  • Pytest 全流程解析:执行机制与报告生成实战指南
  • java 插入式注解的打开方式!
  • MySQL,Redis重点面试题
  • SQL179 每个6/7级用户活跃情况
  • Spring Framework源码解析——BeanPostProcessor
  • 【学习嵌入式day-22-Linux软件编程-IO】
  • SpringBoot集成支付宝二维码支付接口详解
  • Python3.10 + Firecrawl 下载 Markdown 文档:构建高效通用文章爬虫
  • 不同FPGA开发板系统移植步骤
  • Chrome插件开发【Service Worker练手小项目】
  • 【LeetCode刷题集】--排序(三)
  • 【智能的起源】人类如何模仿,简单的“刺激-反应”机制 智能的核心不是记忆,而是发现规律并能迁移到新场景。 最原始的智能:没有思考,只有简单条件反射
  • Mamba 原理汇总2
  • AI(2)-神经网络(激活函数)
  • 支持小语种的在线客服系统,自动翻译双方语言,适合对接跨境海外客户