当前位置: 首页 > java >正文

PyTorch 张量核心知识点

文章目录

  • PyTorch 张量核心知识点
    • 一、张量基础认知
      • 1. 张量的定义
      • 2. 张量的维度与形状
    • 二、张量创建方法
      • 1. 直接创建(基于已知数据)
      • 2. 特殊值张量
      • 3. 随机张量
      • 4. 基于已有张量创建(形状匹配)
    • 三、张量数据类型
      • 1. 常见数据类型
      • 2. 数据类型指定与转换
    • 四、张量访问与取值
      • 1. 索引访问(多维索引)
      • 2. 切片访问(范围取值)
      • 3. 单个元素提取(`item()`)
      • 4. 掩码取值(布尔索引)
    • 五、张量形状修改
      • 1. 重塑(`reshape`/`view`)
      • 2. 维度重排(`permute`/`transpose`)
      • 3. 维度压缩与扩展(`squeeze`/`unsqueeze`)
      • 4. 维度扩展(`expand`/`expand_as`)
    • 六、张量运算
      • 1. 基础算术运算
      • 2. 广播机制(Broadcast)
      • 3. 数学函数
        • (1)三角函数
        • (2)比较函数
        • (3)统计函数
      • 4. 矩阵运算
        • 1)普通矩阵乘法(2 维)
        • (2)批量矩阵乘法(3 维及以上)
      • 5. 张量操作(拼接、堆叠、拆分)
        • (1)拼接(`concat`)
        • (2)堆叠(`stack`)
        • (3)拆分(`split`/`chunk`)
        • (4)展平(`flatten`)
    • 七、其他常用操作
      • 1. 克隆(`clone()`)
      • 2. 脱离计算图(`detach()`)
    • 八、核心重点总结
    • 九、广播机制 “三步法” 总结

PyTorch 张量核心知识点

一、张量基础认知

1. 张量的定义

  • 张量(Tensor)是 PyTorch 中数据运算的基本单元,本质是多维数组,用于存储和处理高维数据。
  • PyTorch 神经网络的输入、权重、输出等均以张量形式存在,所有运算均基于张量进行。

2. 张量的维度与形状

  • 维度:张量的 “阶数”,如 0 维(标量)、1 维(向量)、2 维(矩阵)、3 维及以上(高维张量)。
  • 形状(shape/size):描述每个维度的元素个数,格式为(dim1_len, dim2_len, ..., dimN_len)
    • 例:torch.tensor([[[1,2],[3,4]],[[5,6],[7,8]]]) 的形状为 (2,2,2)(2 个 2×2 矩阵)。
  • 查看形状:tensor.shapetensor.size()(两者等价)。

二、张量创建方法

1. 直接创建(基于已知数据)

import torch
# 0 维张量(标量)
t_scalar = torch.tensor(5, dtype=torch.float)
# 1 维张量(向量)
t_vec = torch.tensor([1, 2, 3])
# 3 维张量
t_3d = torch.tensor([[[1,2],[3,4]],[[5,6],[7,8]]])

2. 特殊值张量

函数作用示例
torch.zeros(shape)创建全 0 张量torch.zeros(2, 3) → 2×3 全 0
torch.ones(shape)创建全 1 张量torch.ones(3, 4) → 3×4 全 1
torch.empty(shape)创建空张量(未初始化,值随机)torch.empty(2, 2)
torch.arange(n)创建 0 到 n-1 的连续整数张量(1 维)torch.arange(6) → [0,1,2,3,4,5]

注意:empty 仅分配内存不初始化,速度快但值不可控;zeros/ones 会初始化值,更安全。

3. 随机张量

函数作用示例
torch.rand(shape)0~1 均匀分布随机数torch.rand(2,3)
torch.randn(shape)标准正态分布(均值 0,方差 1)torch.randn(2,3)
torch.randint(low, high, size)[low, high) 整数随机数torch.randint(0,5, (2,3))
torch.normal(mean, std, size)自定义正态分布(均值 mean,标准差 std)torch.normal(mean=2, std=1, size=(2,3))
  • 固定随机种子(确保结果可复现):torch.manual_seed(100)(种子值可自定义)。

4. 基于已有张量创建(形状匹配)

基于某张量的形状创建新张量,避免重复写形状参数:

t = torch.rand(2, 3)  # 原张量形状 (2,3)
t_empty_like = torch.empty_like(t)  # 空张量,形状与 t 一致
t_zeros_like = torch.zeros_like(t)  # 全 0 张量,形状与 t 一致
t_rand_like = torch.rand_like(t)    # 0~1 随机张量,形状与 t 一致

三、张量数据类型

1. 常见数据类型

类型类别具体类型说明
整数型torch.int/torch.int32标准 32 位整数
torch.int64/torch.long64 位整数(常用于索引)
torch.uint8无符号 8 位整数(0~255)
浮点型torch.float/torch.float3232 位浮点数(默认浮点类型)
torch.float64/torch.double64 位浮点数(精度更高)
布尔型torch.bool布尔值(True/False)

2. 数据类型指定与转换

  • 创建时指定:指定dtype参数

    t_int = torch.zeros(2,3, dtype=torch.int)
    t_float = torch.rand(2,3, dtype=torch.float64)
    
  • 创建后转换:

    • 方法 1:tensor.to(dtype)

      tensor.to(dtype)
      t = torch.zeros(2,3)  # 默认 float32
      t_uint8 = t.to(torch.uint8)
      
    • 方法 2:简写方法(如 double()/int()/long()

      t_double = t.double()  # 转 float64
      t_long = t.long()      # 转 int64
      

四、张量访问与取值

1. 索引访问(多维索引)

  • 格式:tensor[dim0_idx, dim1_idx, ..., dimN_idx],支持整数索引。

    t = torch.arange(24).reshape(2,3,4)  # 形状 (2,3,4)
    print(t[0,1,2])  # 取第 0 个 3×4 矩阵的第 1 行第 2 列 → tensor(6)
    

2. 切片访问(范围取值)

  • 格式:tensor[dim0_slice, dim1_slice, ...],支持 start:end:step 切片语法。

    print(t[:, 1, 1:3])  # 所有矩阵的第 1 行、第 1-2 列 → 形状 (2,2)
    

3. 单个元素提取(item()

  • 仅适用于单元素张量(如标量或形状为 (1,) 的张量),返回 Python 原生类型。

    t_single = t[0,1,2]  # 单元素张量
    print(t_single.item())  # 6(Python 整数)
    

4. 掩码取值(布尔索引)

  • 用布尔张量筛选满足条件的元素,常用于批量修改。

    t = torch.randint(0,3, (5,5))  # 5×5 整数张量
    mask = t == 0  # 布尔掩码:True 表示元素为 0 的位置
    t[mask] = -1   # 将所有为 0 的元素改为 -1
    

五、张量形状修改

1. 重塑(reshape/view

  • 作用:改变张量形状,元素总数不变(各维度长度乘积需等于原总数)。

    t = torch.arange(24)  # 形状 (24,)
    t_234 = t.reshape(2,3,4)  # 重塑为 (2,3,4)
    t_view = t.view(4,6)      # 视图方式重塑为 (4,6)
    
  • 区别:

    • reshape:可能重新分配内存(若原张量内存不连续)。
    • view:仅创建视图(共享内存,不重新分配),仅适用于内存连续的张量。
  • 便捷语法:用 -1 自动计算某维度长度(仅一个 -1 有效)

    t_auto = t.reshape(2, -1, 4)  # -1 自动计算为 3 → 形状 (2,3,4)
    

2. 维度重排(permute/transpose

  • 作用:改变维度的顺序,不改变元素值。

    • transpose(dim1, dim2):仅交换两个维度。
    • permute(dim0, dim1, ...):任意重排所有维度。
    t = torch.rand(3,4,5)  # 原形状 (3,4,5)
    t_trans = t.transpose(1,2)  # 交换 1、2 维 → 形状 (3,5,4)
    t_perm = t.permute(2,0,1)   # 重排为 (5,3,4)
    
  • 特殊:二维张量转置可直接用 tensor.T

    t_mat = torch.rand(3,4)
    t_mat_T = t_mat.T  # 转置为 (4,3)
    

3. 维度压缩与扩展(squeeze/unsqueeze

  • squeeze(dim):删除长度为 1 的维度(不指定 dim 则删除所有长度为 1 的维度)。

    t = torch.rand(1,3,1,4)  # 形状 (1,3,1,4)
    t_sq = t.squeeze()        # 删除所有长度 1 维度 → (3,4)
    t_sq0 = t.squeeze(0)      # 仅删除第 0 维 → (3,1,4)
    
  • unsqueeze(dim):在指定位置插入长度为 1 的维度。

    t = torch.rand(3,4)       # 形状 (3,4)
    t_usq0 = t.unsqueeze(0)   # 第 0 维插入 → (1,3,4)
    t_usq2 = t.unsqueeze(2)   # 第 2 维插入 → (3,4,1)
    

4. 维度扩展(expand/expand_as

  • 作用:将长度为 1 的维度扩展为指定长度(浅表复制,共享内存,不新增元素)。

    t = torch.arange(6).reshape(2,1,3)  # 形状 (2,1,3)
    t_exp = t.expand(2,4,3)             # 第 1 维从 1 扩展到 4 → (2,4,3)
    t_exp_auto = t.expand(-1,4,-1)      # -1 表示保留原长度 → 同上
    
  • expand_as(tensor):扩展为目标张量的形状(需满足广播条件)。

    t_target = torch.rand(2,4,3)
    t_exp_as = t.expand_as(t_target)  # 扩展为 (2,4,3)
    

六、张量运算

1. 基础算术运算

  • 与标量运算:张量的每个元素与标量进行运算(+、-、×、/、**、%)。

    t = torch.arange(6).reshape(2,3)  # [[0,1,2],[3,4,5]]
    print(t + 2)  # 所有元素加 2
    print(t * 3)  # 所有元素乘 3
    print(t **2)  # 所有元素平方
    
  • 同形张量运算:形状完全相同的张量,对应元素逐一运算。

    t2 = torch.tensor([[0,0,1],[2,1,0]])
    print(t + t2)  # 对应元素相加
    

2. 广播机制(Broadcast)

  • 定义:自动扩展形状不同的张量,使它们可进行元素级运算(无需显式扩展)。

  • 广播条件:两个张量从最右侧维度开始比较,每个维度满足 “长度相等” 或 “其中一个为 1”。

    x = torch.randint(0,3, (2,3,1))  # 形状 (2,3,1)
    y = torch.randint(0,3, (3,2))    # 形状 (3,2)
    print(x + y)  # 广播后 x 为 (2,3,2),y 为 (1,3,2) → 结果 (2,3,2)
    

3. 数学函数

(1)三角函数
  • 需先将角度转为弧度(torch.deg2rad())。

    angles = torch.tensor([30, 60])
    radians = torch.deg2rad(angles)
    sin_vals = torch.sin(radians)  # 正弦值
    
(2)比较函数
函数作用
torch.eq(t1, t2)逐元素判断是否相等(返回布尔张量)
torch.equal(t1, t2)判断两个张量完全相同(形状 + 值)
torch.allclose(t1, t2, atol=ε)判断数值近似相等(atol 为允许误差)
(3)统计函数
函数作用示例
tensor.max(dim)沿指定维度求最大值(返回值 + 索引)t.max(dim=0) → 按列求最大
tensor.min(dim)沿指定维度求最小值t.min(dim=1) → 按行求最小
tensor.mean(dim)沿指定维度求均值t.mean(dim=0) → 按列求均值
tensor.var(dim)沿指定维度求方差t.var(dim=1) → 按行求方差
tensor.std(dim)沿指定维度求标准差(方差开根号)t.std(dim=0) → 按列求标准差
  • 限制值范围:

    t = torch.randn(5,5)  # 正态分布随机数
    t_clamp = torch.clamp(t, -0.1, 0.1)  # 限制在 [-0.1, 0.1]
    

4. 矩阵运算

  • 1)普通矩阵乘法(2 维)
    • 要求:前一个张量的最后一维长度 = 后一个张量的倒数第二维长度
    • 函数:torch.matmul(t1, t2) 或简写 t1 @ t2torch.mm(t1, t2)
    A = torch.randint(1,4, (2,3))  # 2×3 矩阵
    B = torch.randint(1,4, (3,4))  # 3×4 矩阵
    C = A @ B  # 结果为 2×4 矩阵
    
  • 区别:matmul 支持广播,mm 仅支持 2 维张量且不广播。

(2)批量矩阵乘法(3 维及以上)
  • 作用:对批量的矩阵逐一相乘(前 N-2 维为批量维度,最后 2 维为矩阵维度)。

  • 函数:torch.bmm(t1, t2)

    A = torch.randint(1,4, (5,2,3))  # 5 个 2×3 矩阵
    B = torch.randint(1,4, (5,3,4))  # 5 个 3×4 矩阵
    C = torch.bmm(A, B)  # 结果为 5 个 2×4 矩阵 → 形状 (5,2,4)
    

5. 张量操作(拼接、堆叠、拆分)

(1)拼接(concat
  • 作用:将多个张量沿已有维度拼接(不新增维度)。

  • 要求:除拼接维度外,其他维度形状必须一致。

    A = torch.rand(2,3,4)
    B = torch.rand(2,3,4)
    C_dim0 = torch.concat([A,B], dim=0)  # 沿第 0 维拼接 → (4,3,4)
    C_dim1 = torch.concat([A,B], dim=1)  # 沿第 1 维拼接 → (2,6,4)
    
(2)堆叠(stack
  • 作用:将多个张量沿新增维度堆叠(会新增维度)。

  • 要求:所有张量形状必须完全一致。

    A = torch.rand(3,4)
    B = torch.rand(3,4)
    C_dim0 = torch.stack([A,B], dim=0)  # 新增第 0 维 → (2,3,4)
    C_dim2 = torch.stack([A,B], dim=2)  # 新增第 2 维 → (3,4,2)
    
(3)拆分(split/chunk
  • split(segment_len, dim):按 “每段长度” 拆分。

  • chunk(num_chunks, dim):按 “拆分块数” 拆分。

    t = torch.rand(3,6,5)  # 形状 (3,6,5)
    # split:每段长度 2,沿第 1 维拆分
    A1,B1,C1 = t.split(2, dim=1)  # 每段形状 (3,2,5)
    # chunk:拆分为 3 块,沿第 1 维拆分
    A2,B2,C2 = t.chunk(3, dim=1)  # 每块形状 (3,2,5)
    
(4)展平(flatten
  • 作用:将指定维度范围合并为一个维度。

  • 格式:torch.flatten(tensor, start_dim, end_dim)(默认 start_dim=0end_dim=-1)。

    t = torch.rand(2,3,4,5)
    t_flat1 = t.flatten(start_dim=1)  # 第 1-3 维展平 → (2, 60)
    t_flat2 = t.flatten(1,2)          # 第 1-2 维展平 → (2, 12, 5)
    

七、其他常用操作

1. 克隆(clone()

  • 作用:创建张量的深拷贝(新张量与原张量值相同,但内存独立)。

    t = torch.arange(10)
    t_clone = t.clone()
    t[0] = 100  # 修改原张量,克隆张量不受影响
    print(t_clone)  # 仍为 [0,1,2,...,9]
    

2. 脱离计算图(detach()

  • 作用:创建张量的浅拷贝,脱离当前计算图(仅用于推理,不参与梯度计算)。

    t = torch.arange(10, requires_grad=True)
    t_detach = t.detach()  # 脱离计算图,无梯度
    

八、核心重点总结

  1. 形状匹配:所有张量运算需确保形状兼容(广播机制可简化部分场景)。
  2. 维度操作reshape(重塑)、permute(重排)、squeeze/unsqueeze(增减维度)是高频操作。
  3. 矩阵乘法matmul(支持广播)、bmm(批量矩阵)需注意维度匹配。
  4. 内存效率viewexpand 共享内存,reshapeclone 可能重新分配内存,按需选择。

九、广播机制 “三步法” 总结

遇到任何广播场景,都可以按以下步骤判断:

  1. 补维度:给维度数少的张量左侧补 1,直到两个张量维度数一致;
  2. 比维度:从最右侧维度开始,逐维对比,每个维度需满足 “相等” 或 “其中一个为 1”;
  3. 扩维度:将所有 “长度为 1” 的维度,扩展为另一个张量对应维度的长度,最终两个张量形状完全一致。

通过以上例子,能覆盖 90% 以上的广播场景,核心是 “右对齐、补 1 维、判规则、扩长度”,多练两次就能快速判断~

http://www.xdnf.cn/news/19017.html

相关文章:

  • 引入资源即针对于不同的屏幕尺寸,调用不同的css文件
  • KubeBlocks For MySQL 云原生设计分享
  • 三大压测工具对比:Siege/ab/Wrk实战指南
  • SpringBoot系列之实现高效批量写入数据
  • 基础IO详解
  • 【前缀和】
  • Pandas的数据结构
  • 第十七章 Java基础-常用API-System
  • [p2p-Magnet] 数据模型(GORM) | DHT爬虫 | 分类器
  • React Hook+Ts+Antd+SpringBoot实现分片上传(前端)
  • 数据湖与数据仓库
  • Qt 中日志级别
  • ArcGIS+Fragstats:土地利用统计分析、景观格局指数计算与地图制图
  • Android Keystore签名文件详解与安全防护
  • AI视频生成工具全景对比:元宝AI、即梦AI、清影AI和Vidu AI
  • 【贪心 单调栈】P10334 [UESTCPC 2024] 饮料|普及+
  • 工业 5G + AI:智能制造的未来引擎
  • Day16_【机器学习建模流程】
  • 【Rust】 3. 语句与表达式笔记
  • Java HTTP 请求:Unirest 使用指南及与 HttpClient 对比
  • .Net Core Web 架构(Request Pipeline)的底层实现
  • 自己定义的模型如何用hf的from_pretrained
  • Linux(一) | 初识Linux与目录管理基础命令掌握
  • 测试题ansible临时命令模块
  • CuTe C++ 简介01,从示例开始
  • imx6ull-驱动开发篇47——Linux SPI 驱动实验
  • Electron解压缩文件
  • hive on tez为什么写表时,要写临时文件到hdfs目录
  • docker 1分钟 快速搭建 redis 哨兵集群
  • 配置nginx.conf (增加21001端口实例操作)