当前位置: 首页 > java >正文

【实战1】手写字识别 Pytoch(更新中)

1 数据集

引用

import torch
from torch import nn ##nn创建神经网络
from torch.utils.data import DataLoader #DataLoader加载数据
from torchvision import datasets #datasets加载数据集
from torchvision.transforms import ToTensor #ToTensor将数据转换为张量

下载数据集 

# 下载数据
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor()
)testing_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor()
)

导入数据

# 加载数据
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(testing_data, batch_size=batch_size)# 打印第一个批次的数据形状 
for X, y in train_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break

使用keras构建神经网络的方法

构建神经网络

定义神经网络类,继承自nn.Moudule

Pytorch使用nn.Sequentail对象创建神经网络

神经网络必须实现前向传播forward()方法

使用nn.Flatten对象将图片转成列列向量【对比Kares ,pytorch是面向对象的思路】

nn.to方法可以让模型运行在不同的硬件上

device = ("cuda"if torch.cuda.is_available()else "mps"if torch.backends.mps.is_available() else "cpu")# 打印设备信息
print(f"Using {device} device")# 构建神经网络 2层 输入28*28,第一层128个神经元,激活函数relu,第二层10个神经元
class NeunalNetwork(nn.Module):def __init__(self):super().__init__()self.nmodel = nn.Sequential(nn.Linear(28*28, 128), # 输入层到隐藏层 nn.ReLU(), # 激活函数nn.Linear(128, 10) # 隐藏层到输出层)def forward(self, x):x = nn.Flatten(1)(x) # 将输入展平output = self.nmodel(x) #对模型训练return outputmodel = NeunalNetwork().to(device)
print(model)

编译神经网络

训练神经网络

评估神经网络

http://www.xdnf.cn/news/15798.html

相关文章:

  • 【no vue no bug】 npm : 无法加载文件 D:\software\nodeJS\node22\npm.ps1
  • 嵌入式硬件篇---舵机(示波器)
  • 小架构step系列20:请求和响应的扩展点
  • 解锁Phpenv:轻松搭建PHP集成环境指南
  • 使用“桥接模式“,实现跨平台绘图或多类型消息发送机制
  • 抓包工具使用教程
  • PaliGemma 2-轻量级开放式视觉语言模型
  • 【RocketMQ 生产者和消费者】- 消费者发起消息拉取请求 PullMessageService
  • ps2025下载与安装教程(附安装包) 2025最新版photoshop安装教程
  • 群组功能实现指南:从数据库设计到前后端交互,上班第二周
  • SElinux和iptables介绍
  • Kafka——Java生产者是如何管理TCP连接的?
  • MCP 协议详细分析一 initialize ping tools/list tools/call
  • C++数据结构————集合
  • 暑期训练8
  • 读书笔记:最好使用C++转型操作符
  • MCP 协议详细分析 二 Sampling
  • NX二次开发常用函数——从一个坐标系到另一个坐标系的转换(UF_MTX4_csys_to_csys )相同体坐标转化
  • Supertest(Node.js)接口测试
  • NJU 凸优化导论(9) 对偶(II)KKT条件+变形重构
  • 笔试强训——第一周
  • 阿里云服务器 CentOS 7 安装 MySQL 8.4 超详细指南
  • 2025年医疗人工智能发展现状
  • 网络基础DAY14-可靠性概念及要求+链路聚合
  • 机器学习漫画小抄 - 彩图版
  • 『 C++ 入门到放弃 』- AVL树
  • 了解.NET Core状态管理:优化技巧与常见问题解决方案
  • 暑假--作业3
  • Linux 自旋锁
  • 13.4 Meta LLaMA开源模型家族全面解析:从Alpaca到Vicuna的技术内幕