当前位置: 首页 > ops >正文

PyTorch 张量(Tensor)详解:从基础到实战

1. 引言

在深度学习和科学计算领域,张量(Tensor) 是最基础的数据结构。PyTorch 作为当前最流行的深度学习框架之一,其核心计算单元就是张量。与 NumPy 的 ndarray 类似,PyTorch 张量支持高效的数值计算,但额外提供了 GPU 加速 和 自动微分(Autograd) 功能,使其成为构建和训练神经网络的理想选择。

本文将全面介绍 PyTorch 张量的核心概念、基本操作、高级特性及实际应用,帮助读者掌握张量的使用方法,并理解其在深度学习中的作用。

2. 什么是张量?

张量是多维数组的泛化,可以表示不同维度的数据:

  • 0D 张量(标量):单个数值,如 torch.tensor(5)

  • 1D 张量(向量):一维数组,如 torch.tensor([1, 2, 3])

  • 2D 张量(矩阵):二维数组,如 torch.tensor([[1, 2], [3, 4]])

  • 3D+ 张量(高阶张量):如 RGB 图像(3D)、视频数据(4D)等

PyTorch 张量的主要特点:

  1. 支持 GPU 加速:可无缝切换 CPU/GPU 计算。

  2. 自动微分:用于神经网络的反向传播。

  3. 动态计算图:更灵活的模型构建方式(与 TensorFlow 1.x 的静态计算图不同)。

3. 张量的创建与初始化

3.1 从 Python 列表或 NumPy 数组创建

import torch
import numpy as np# 从列表创建
t1 = torch.tensor([1, 2, 3])  # 1D 张量
t2 = torch.tensor([[1, 2], [3, 4]])  # 2D 张量# 从 NumPy 数组创建
arr = np.array([1, 2, 3])
t3 = torch.from_numpy(arr)  # 共享内存(修改一个会影响另一个)

3.2 特殊初始化方法

# 全零张量
zeros = torch.zeros(2, 3)  # 2x3 的零矩阵# 全一张量
ones = torch.ones(2)  # [1., 1.]# 随机张量
rand_uniform = torch.rand(2, 2)  # 0~1 均匀分布
rand_normal = torch.randn(2, 2)  # 标准正态分布# 类似现有张量的形状
x = torch.tensor([[1, 2], [3, 4]])
x_like = torch.rand_like(x)  # 形状与 x 相同,值随机

4. 张量的基本属性

每个 PyTorch 张量都有以下关键属性:

x = torch.rand(2, 3, dtype=torch.float32, device="cuda")print(x.shape)      # 形状: torch.Size([2, 3])
print(x.dtype)      # 数据类型: torch.float32
print(x.device)     # 存储设备: cpu / cuda
print(x.requires_grad)  # 是否启用梯度计算(用于 Autograd)

4.1 数据类型(dtype)

PyTorch 支持多种数据类型:

  • torch.float32(默认)

  • torch.int64

  • torch.bool(布尔张量)

可以通过 .to() 方法转换:

x = torch.tensor([1, 2], dtype=torch.float32)
y = x.to(torch.int64)  # 转换为整型

4.2 设备(CPU/GPU)

PyTorch 允许张量在 CPU 或 GPU 上运行:

if torch.cuda.is_available():device = torch.device("cuda")x = x.to(device)  # 移动到 GPUy = y.to("cuda")  # 简写方式

5. 张量的基本运算

5.1 算术运算

a = torch.tensor([1, 2])
b = torch.tensor([3, 4])# 加法
c = a + b  # 等价于 torch.add(a, b)# 乘法(逐元素)
d = a * b  # [3, 8]# 矩阵乘法
mat_a = torch.rand(2, 3)
mat_b = torch.rand(3, 2)
mat_c = torch.matmul(mat_a, mat_b)  # 或 mat_a @ mat_b

5.2 形状操作

x = torch.rand(4, 4)# 改变形状(类似 NumPy 的 reshape)
y = x.view(16)  # 展平为一维张量
z = x.view(2, 8)  # 调整为 2x8# 转置
x_t = x.permute(1, 0)  # 行列交换# 扩维 / 压缩
x_expanded = x.unsqueeze(0)  # 增加一个维度(1x4x4)
x_squeezed = x_expanded.squeeze()  # 去除大小为1的维度

5.3 索引与切片

x = torch.rand(3, 4)# 取第一行
row = x[0, :]# 取前两列
cols = x[:, :2]# 布尔索引
mask = x > 0.5
filtered = x[mask]  # 返回满足条件的元素

6. 自动微分(Autograd)

PyTorch 的 autograd 模块支持自动计算梯度,适用于反向传播:

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2 + 3 * x  # 计算图构建
y.backward()  # 反向传播
print(x.grad)  # dy/dx = 2x + 3 → 7.0

6.1 禁用梯度计算

with torch.no_grad():y = x * 2  # 不记录梯度

7. 张量与 NumPy 的互操作

PyTorch 张量可以无缝转换为 NumPy 数组:

# Tensor → NumPy
a = torch.rand(2, 2)
b = a.numpy()  # 共享内存(修改一个会影响另一个)# NumPy → Tensor
c = np.array([1, 2])
d = torch.from_numpy(c)  # 共享内存

8. 实际应用示例

8.1 线性回归(手动实现)

# 数据准备
X = torch.rand(100, 1)
y = 3 * X + 2 + 0.1 * torch.randn(100, 1)# 初始化参数
w = torch.randn(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)# 训练
lr = 0.01
for epoch in range(100):y_pred = w * X + bloss = ((y_pred - y) ** 2).mean()loss.backward()  # 计算梯度with torch.no_grad():w -= lr * w.gradb -= lr * b.gradw.grad.zero_()b.grad.zero_()print(f"w: {w.item()}, b: {b.item()}")

8.2 张量在 CNN 中的应用

import torch.nn as nn# 模拟输入(batch_size=1, channels=3, height=32, width=32)
input_tensor = torch.rand(1, 3, 32, 32)# 定义一个简单的 CNN
model = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3),nn.ReLU(),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(16 * 15 * 15, 10)  # 假设输出 10 类
)output = model(input_tensor)
print(output.shape)  # torch.Size([1, 10])

9. 总结

PyTorch 张量是深度学习的基础数据结构,支持:

  • 多维数组计算(类似 NumPy)

  • GPU 加速(大幅提升计算速度)

  • 自动微分(简化神经网络训练)

  • 动态计算图(灵活调试模型)

掌握张量的基本操作是学习 PyTorch 的关键步骤。建议读者通过官方文档和实际项目加深理解,逐步掌握张量的高级用法(如广播机制、高级索引等)。

http://www.xdnf.cn/news/19291.html

相关文章:

  • 【深度学习】配分函数:近似最大似然与替代准则
  • python复杂代码如何让ide自动推导提示内容
  • 编写Linux下usb设备驱动方法:disconnect函数中要完成的任务
  • More Effective C++ 条款20:协助完成返回值优化(Facilitate the Return Value Optimization)
  • 每日算法题【栈和队列】:栈和队列的实现、有效的括号、设计循环队列
  • [软考中级]嵌入式系统设计师—考核内容分析
  • Redis持久化之AOF(Append Only File)
  • Java基础知识(十二)
  • 8.31【Q】CXL-DMSim:
  • vue3+vite+ts 发布npm 组件包
  • Deep Think with Confidence:llm如何进行高效率COT推理优化
  • 第24章学习笔记|用正则表达式解析文本文件(PowerShell 实战)
  • zkML-JOLT——更快的ZK隐私机器学习:Sumcheck +Lookup
  • Pytest 插件介绍和开发
  • leetcode 260 只出现一次的数字III
  • COLA:大型语言模型高效微调的革命性框架
  • 免费电脑文件夹加密软件
  • 基于Adaboost集成学习与SHAP可解释性分析的分类预测
  • 【K8s】整体认识K8s之存储--volume
  • 在win服务器部署vue+springboot + Maven前端后端流程详解,含ip端口讲解
  • Transformer架构三大核心:位置编码(PE)、前馈网络(FFN)和多头注意力(MHA)。
  • 学习Python中Selenium模块的基本用法(12:操作Cookie)
  • TFS-2005《A Possibilistic Fuzzy c-Means Clustering Algorithm》
  • 使用 Python 自动化检查矢量面数据的拓扑错误(含导出/删除选项)
  • 算法题(196):最大异或对
  • 特殊符号在Html中的代码及常用标签格式的记录
  • Qt组件布局的经验
  • 线程池、锁策略
  • 机器视觉opencv教程(四):图像颜色识别与颜色替换
  • Linux中的ss命令