当前位置: 首页 > ops >正文

PyTorch的基础概念和复杂模型的基本使用

文章目录

    • 一、PyTorch基础概念
    • 二、复杂模型的学习使用

一、PyTorch基础概念

  1. 张量(Tensor)操作
    • 张量是PyTorch中的基本数据结构,类似于NumPy的数组,但支持GPU加速
    • 常见操作包括创建张量、张量运算、索引、切片等
import torch# 创建张量
x = torch.randn(3, 4)
y = torch.zeros(3, 4)# 张量运算
z = x + y
  1. 自动求导(Autograd)
    • PyTorch的自动求导系统可以自动计算梯度
    • 通过requires_grad=True启用梯度计算
# 启用自动求导
x = torch.randn(3, 4, requires_grad=True)# 计算损失
y = x * 2
loss = y.sum()# 反向传播
loss.backward()
  1. 计算图
    • PyTorch使用动态计算图(Define-by-Run)的方式
    • 每次前向传播都会构建一个新的计算图

二、复杂模型的学习使用

  1. 神经网络模块(torch.nn)
    • torch.nn提供了构建神经网络所需的各种组件
    • 主要包括各种层(如线性层、卷积层)、激活函数、损失函数等
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.fc1(x))x = self.fc2(x)return x
  1. 卷积神经网络(CNN)
    • 适用于图像处理任务
    • 包含卷积层、池化层等
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3)self.conv2 = nn.Conv2d(32, 64, kernel_size=3)self.fc1 = nn.Linear(12*12*64, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2)x = x.view(-1, 12*12*64)x = F.relu(self.fc1(x))x = self.fc2(x)return x
  1. 循环神经网络(RNN)
    • 适用于序列数据处理任务
    • 包括RNN、LSTM、GRU等变体
class RNNModel(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNNModel, self).__init__()self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, num_classes)def forward(self, x):out, _ = self.lstm(x)out = self.fc(out[:, -1, :])return out
  1. 训练流程
    • 数据加载:使用DataLoaderDataset加载数据
    • 模型定义:定义神经网络结构
    • 损失函数:选择合适的损失函数(如交叉熵损失)
    • 优化器:选择优化器(如Adam)并传入模型参数
    • 训练循环:执行前向传播、计算损失、反向传播和参数更新
from torch.utils.data import DataLoader, TensorDataset# 创建数据集
x_train = torch.randn(1000, 784)
y_train = torch.randint(0, 10, (1000,))
dataset = TensorDataset(x_train, y_train)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)# 创建模型、损失函数和优化器
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练循环
for epoch in range(10):for inputs, targets in dataloader:outputs = model(inputs)loss = criterion(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()
  1. 模型保存与加载
    • 使用torch.save()torch.load()保存和加载模型
# 保存模型
torch.save(model.state_dict(), "model.pth")# 加载模型
model = Net()
model.load_state_dict(torch.load("model.pth"))
http://www.xdnf.cn/news/15828.html

相关文章:

  • Facebook 开源多季节性时间序列数据预测工具:Prophet 快速入门 Quick Start
  • macOs上交叉编译ffmpeg及安装ffmpeg工具
  • 测试中的bug
  • 基于深度学习的自然语言处理:构建情感分析模型
  • urllib.parse.urlencode 的使用详解
  • AI+预测3D新模型百十个定位预测+胆码预测+去和尾2025年7月20日第144弹
  • Uniapp 纯前端台球计分器开发指南:能否上架微信小程序 打包成APP?
  • 安全信息与事件管理(SIEM)系统架构设计
  • 【前端】懒加载(组件/路由/图片等)+预加载 汇总
  • AI绘画生成东汉末年赵云全身像的精细提示词
  • 四、多频技术与复杂场景处理
  • 基于卷积傅里叶分析网络 (CFAN)的心电图分类的统一时频方法
  • SpringBoot3集成MapstructPlus
  • GaussDB select into和insert into的用法
  • 基于智慧经营系统的学校住宿登记报表分析与应用探究-毕业论文—仙盟创梦IDE
  • Qt--Widget类对象的构造函数分析
  • 上电复位断言的自动化
  • 网络安全初级(前端页面的编写分析)
  • Java 递归方法详解:从基础语法到实战应用,彻底掌握递归编程思想
  • C++STL系列之list
  • Spring Boot 第一天知识汇总
  • UE5多人MOBA+GAS 26、为角色添加每秒回血回蓝(番外:添加到UI上)
  • redis-plus-plus安装与使用
  • 【vue-7】Vue3 响应式数据声明:深入理解 reactive()
  • 敏捷开发的历史演进:从先驱实践到全域敏捷(1950s-2025)
  • ubuntu 24.04 xfce4 钉钉输入抢焦点问题
  • XSS的学习笔记
  • ChatIM项目语音识别安装与使用
  • 拓展面试题之-rabbitmq面试题
  • [Python] -项目实战8- 构建一个简单的 Todo List Web 应用(Flask)