当前位置: 首页 > backend >正文

PyTorch_指定运算设备 (包含安装 GPU 的 PyTorch)

PyTorch默认会将张量创建在 CPU 控制的内存中,即:默认的运算设备为 CPU。我们也可以将张量创建在 GPU 上,能够利用对于矩阵计算的优势加快模型训练。

将张量移动到 GPU 上有两种方法:

  1. 使用 cuda 方法
  2. 直接在 GPU 上创建张量
  3. 使用 to 方法指定设备

安装含有 GPU 的 PyTorch

通过这个可以判断电脑里是否已经安装 CUDA

import torch print(torch.cuda.is_available()) # 判断是否有可用的 GPU 设备

如果结果输出是 False,说明设备里没有 GPU 的 PyTorch 的版本。

可以通过 PyTorch官网来安装含有 GPU 的 PyTorch 的版本。

安装 CUDA 12.8 版本的pip命令

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

这样就安装好含有 GPU 的 PyTorch 了。


代码

import torch # 使用 cuda 方法
def test01():data = torch.tensor([10, 20, 30])print("存储设备:", data.device)# 将张量移动到 GPU 设备上data = data.cuda()print("存储设备:", data.device)# 将张量从 GPU 移动到 CPU 设备上data = data.cpu()print("存储设备:", data.device)# 直接将张量创建在指定设备上
def test02():data = torch.tensor([10, 20, 30], device='cuda')print("存储设备:", data.device)# 将张量移动到 CPU 设备上data = data.cpu()print("存储设备:", data.device)# 使用 to 方法
def test03():data = torch.tensor([10, 20, 30])print("存储设备:", data.device)# 将张量移动到 GPU 设备上data = data.to('cuda')print("存储设备:", data.device)# 将张量从 GPU 移动到 CPU 设备上data = data.to('cpu')print("存储设备:", data.device)# 注意点:张量存储在不同设备上的张量不能够直接运算
def test04():data1 = torch.tensor([10,20,30])data2 = torch.tensor([1,2,3], device='cuda')# RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!# data1 = data1.to('cuda') # 这样可以解决了data = data1 + data2 print(data)# 如果你的电脑上安装 pytorch 不是 gpu 版本的,或者电脑本身没有 gpu 设备环境# 否则下面的调用 cuda 函数的代码会报错# data1 = data1.cuda() if __name__ == "__main__":test04() 
http://www.xdnf.cn/news/3715.html

相关文章:

  • C++八股--5--设计模式--适配器模式,代理模式,观察者模式
  • 大数据:驱动技术创新与产业转型的引擎
  • 【RocketMQ NameServer】- NettyEventExecutor 处理 Netty 事件
  • 网格不迷路:用 CSS 网格生成器打造完美布局
  • Linxu基本操作
  • 单片机裸机环境下临界区保护
  • Golang WaitGroup 用法 源码阅读笔记
  • # LeetCode 1007 行相等的最少多米诺旋转
  • 动态规划-1137.第N个泰波那契数-力扣(LeetCode)
  • 【iview】es6变量结构赋值(对象赋值)
  • 【LLaMA-Factory实战】1.3命令行深度操作:YAML配置与多GPU训练全解析
  • 轻量级RTSP服务模块:跨平台低延迟嵌入即用的流媒体引擎
  • 从融智学视域快速回顾世界历史和主要语言文字最初历史证据(列表对照分析比较)
  • Vue实现成绩增删案例
  • C++ 中的继承
  • JSON 处理笔记
  • npm pnpm yarn 设置国内镜像
  • 数据库原理与应用实验二 题目七
  • PowerShell安装Chocolatey
  • 哈希函数详解(SHA-2系列、SHA-3系列、SM3国密)案例:构建简单的区块链——密码学基础
  • Python刷题:流程控制(下)
  • PowerPC架构详解:定义、应用及特点
  • 【PostgreSQL数据分析实战:从数据清洗到可视化全流程】1.1 数据库核心概念与PostgreSQL技术优势
  • C++负载均衡远程调用学习之 Dns-Route关系构建
  • 代码随想录算法训练营Day43
  • 超预期!淘宝闪购提前开放全国全量,联合饿了么扭转外卖战局
  • 美丽天天秒链动2+1源码(新零售商城搭建)
  • P4314 CPU 监控 Solution
  • YOLO旋转目标检测之ONNX模型推理
  • 上位机知识篇---粗细颗粒度