当前位置: 首页 > java >正文

Pytorch中张量的索引和切片使用详解和代码示例

PyTorch 中张量索引与切片详解

使用前先导入:

import torch

1.基础索引(类似 Python / NumPy)

适用于低维张量:x[i]x[i, j]

x = torch.tensor([[10, 11, 12],[13, 14, 15],[16, 17, 18]])print(x[0])         # 第0行: tensor([10, 11, 12])
print(x[1][2])      # 第1行第2列: 15
print(x[2, 1])      # 第2行第1列: 17

2.切片(Slicing)

x = torch.arange(16).reshape(4, 4)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])print(x[:2])        # 前两行
print(x[:, 1:3])    # 所有行,第1~2列
print(x[::2, ::2])  # 行列间隔为2

3.负索引

print(x[-1])        # 最后一行
print(x[:, -2:])    # 每行最后两列

4.使用 ... (Ellipsis)

当维度很多时可简化操作。

x = torch.arange(2*3*4).reshape(2, 3, 4)# 等价于 x[0, :, 2]
print(x[0, ..., 2])

5.Noneunsqueeze 增加维度

x = torch.tensor([1, 2, 3])# 增加维度(等价于 unsqueeze)
print(x[None, :].shape)     # torch.Size([1, 3])
print(x[:, None].shape)     # torch.Size([3, 1])

6. 布尔索引(Boolean Indexing)

x = torch.tensor([10, 20, 30, 40])mask = x > 25
print(mask)         # tensor([False, False,  True,  True])
print(x[mask])      # tensor([30, 40])

7. 花式索引(Fancy Indexing)

使用索引列表访问多个非连续位置。

x = torch.tensor([10, 20, 30, 40, 50])idx = torch.tensor([0, 2, 4])
print(x[idx])       # tensor([10, 30, 50])

二维花式索引:

x = torch.arange(1, 10).reshape(3, 3)
# tensor([[1, 2, 3],
#         [4, 5, 6],
#         [7, 8, 9]])rows = torch.tensor([0, 1, 2])
cols = torch.tensor([2, 1, 0])
print(x[rows, cols])  # [3, 5, 7]

8. 条件赋值 / where

x = torch.tensor([1, 2, 3, 4, 5])
x[x > 3] = 100
print(x)            # tensor([  1,   2,   3, 100, 100])# 条件选择
a = torch.tensor([1, 2, 3])
b = torch.tensor([10, 20, 30])
cond = torch.tensor([True, False, True])print(torch.where(cond, a, b))  # -> [1, 20, 3]

9. 高维张量索引技巧

x = torch.arange(2*3*4).reshape(2, 3, 4)# 提取第1个 batch 所有通道第2列
print(x[0, :, 2])    # shape: (3,)

10. 实例:图像张量裁剪(HWC)

img = torch.rand((3, 256, 256))  # C, H, W 格式# 裁剪中心区域
crop = img[:, 100:200, 100:200]  # shape (3, 100, 100)

11. 总结图解(结构化索引方式)

张量索引方式:
├── 基础索引(x[i], x[i,j])
├── 切片(x[start:end], x[:, idx])
├── 高维省略(x[..., -1])
├── 增维/降维(x[None, :], x.squeeze())
├── 布尔索引(x[x>val])
├── 花式索引(x[[0, 2, 4]])
├── 条件赋值(x[x > a] = b)
└── torch.where(cond, a, b)

高级应用


1. 高级花式索引(Advanced Fancy Indexing)

基本复习:

花式索引是用整张或部分张量作为索引,获取非连续元素。进阶里,张量的形状组合、广播规则非常重要。

代码示例:

import torchx = torch.arange(27).reshape(3, 3, 3)
# x shape = (3, 3, 3)# 目标:同时选取不同 batch 不同通道的元素
idx_batch = torch.tensor([0, 1, 2])   # 每个 batch 索引
idx_channel = torch.tensor([2, 1, 0]) # 每个对应通道索引
idx_row = torch.tensor([0, 1, 2])     # 对应行索引# 三个索引张量自动广播,选出:
# x[0, 2, 0], x[1, 1, 1], x[2, 0, 2]
result = x[idx_batch, idx_channel, idx_row]print(result)  # tensor([ 6, 13, 24])
  • 关键是各个索引张量形状要匹配或可广播
  • 返回值的形状取决于索引张量的形状。

2. 坐标映射索引(Indexing with Coordinate Tensors)

常用在点云、图像坐标映射,手工给定索引位置批量取值。

代码示例:

x = torch.arange(16).reshape(4, 4)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])# 给定坐标点
coords = torch.tensor([[0, 1], [2, 3], [3, 0]])  # 三个点的坐标rows = coords[:, 0]
cols = coords[:, 1]vals = x[rows, cols]
print(vals)  # tensor([ 1, 11, 12])

torch.gather — 按索引沿指定维度收集数据

x = torch.arange(12).reshape(3, 4)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])indices = torch.tensor([[0, 3], [2, 1], [1, 0]])
result = torch.gather(x, dim=1, index=indices)
print(result)
# tensor([[ 0,  3],
#         [ 6,  5],
#         [ 9,  8]])
  • torch.gather 需要索引张量与输入同形状,但索引值表示该维度的选取位置。

3. 高维图像张量处理技巧

假设图像张量格式为 (Batch, Channels, Height, Width),称为 BCHW。

常用操作示例:

(a) 批量裁剪 (Crop)
img = torch.randn(5, 3, 256, 256)  # 5张RGB图像# 取中心128x128块
h_start = (256 - 128) // 2
w_start = (256 - 128) // 2crop = img[:, :, h_start:h_start+128, w_start:w_start+128]  # shape (5, 3, 128, 128)
(b) 改变通道顺序
# BCHW -> BHWC
img_bhwc = img.permute(0, 2, 3, 1)
print(img_bhwc.shape)  # (5, 256, 256, 3)
© 按坐标索引批量像素点
batch_size = 2
img = torch.arange(batch_size*3*4*4).reshape(batch_size, 3, 4, 4)# 取每张图(0,1)通道,指定像素点坐标
coords = torch.tensor([[1, 2], [3, 0]])  # (batch_size, 2) 像素坐标 (H, W)batch_indices = torch.arange(batch_size)
channels = torch.tensor([0, 1])  # 不同图不同通道pixels = img[batch_indices, channels, coords[:, 0], coords[:, 1]]
print(pixels)

总结:

技巧类别适用场景关键函数/概念
高级花式索引多维非连续索引,索引张量广播多张量索引广播
坐标映射索引点云坐标、图像点批量索引torch.gather, 坐标张量索引
高维图像张量处理批量裁剪、通道转换、批量像素选取permutereshape、多维切片

4.综合示例

下面以一个综合示例代码,涵盖 高级花式索引坐标映射索引,以及 高维图像张量处理,注释详尽,方便大家理解和直接跑起来。

import torchdef advanced_fancy_indexing():print("=== 高级花式索引示例 ===")x = torch.arange(27).reshape(3, 3, 3)idx_batch = torch.tensor([0, 1, 2])idx_channel = torch.tensor([2, 1, 0])idx_row = torch.tensor([0, 1, 2])# 选出 x[0,2,0], x[1,1,1], x[2,0,2]result = x[idx_batch, idx_channel, idx_row]print(result)  # tensor([ 6, 13, 24])print()def coordinate_mapping_indexing():print("=== 坐标映射索引示例 ===")x = torch.arange(16).reshape(4, 4)coords = torch.tensor([[0, 1], [2, 3], [3, 0]])  # 3个坐标点rows = coords[:, 0]cols = coords[:, 1]vals = x[rows, cols]print(f"从坐标 {coords.tolist()} 取值: {vals.tolist()}")# torch.gather示例x2 = torch.arange(12).reshape(3, 4)indices = torch.tensor([[0, 3], [2, 1], [1, 0]])gathered = torch.gather(x2, dim=1, index=indices)print(f"torch.gather 结果:\n{gathered}")print()def high_dim_image_tensor_processing():print("=== 高维图像张量处理示例 ===")# 生成一个 5张RGB图像 BCHW 格式img = torch.randn(5, 3, 256, 256)# 裁剪中心128x128h_start = (256 - 128) // 2w_start = (256 - 128) // 2crop = img[:, :, h_start:h_start+128, w_start:w_start+128]print(f"裁剪后的形状: {crop.shape}")# 通道顺序变换 BCHW -> BHWCimg_bhwc = img.permute(0, 2, 3, 1)print(f"通道转换后形状: {img_bhwc.shape}")# 批量取像素点batch_size = 2img_small = torch.arange(batch_size*3*4*4).reshape(batch_size, 3, 4, 4)coords = torch.tensor([[1, 2], [3, 0]])  # 每张图像的像素坐标 (H, W)batch_indices = torch.arange(batch_size)channels = torch.tensor([0, 1])  # 两张图不同通道pixels = img_small[batch_indices, channels, coords[:, 0], coords[:, 1]]print(f"批量像素值: {pixels.tolist()}")if __name__ == "__main__":advanced_fancy_indexing()coordinate_mapping_indexing()high_dim_image_tensor_processing()

代码说明

  • advanced_fancy_indexing()
    演示多张量广播索引从三维张量中选取不规则元素。

  • coordinate_mapping_indexing()
    演示给定坐标点批量取值 + 用 torch.gather 沿某维度收集。

  • high_dim_image_tensor_processing()
    展示了高维图像张量裁剪、通道排列变换和批量像素点采样。


http://www.xdnf.cn/news/15481.html

相关文章:

  • [ROS 系列学习教程] ROS动作通讯(Action):通信模型、Hello World与拓展
  • B/S 架构通信原理详解
  • 【数据结构】单链表练习(有环)
  • C++(STL源码刨析/stack/queue/priority_queue)
  • Rocky Linux 9 源码包安装php8
  • I3C通信协议核心详解
  • 描述统计1
  • 百度移动开发面经合集
  • 【PCIe 总线及设备入门学习专栏 5.1.2 -- PCIe EP core_rst_n 与 app_rst_n】
  • Java 大视界 -- Java 大数据机器学习模型在金融风险传染路径分析与防控策略制定中的应用(347)
  • HTML网页结构(基础)
  • 使用Spring Cloud LoadBalancer报错java.lang.IllegalStateException
  • Nestjs框架: 数据库架构设计与 NestJS 多 ORM 动态数据库应用与连接池的配置
  • QTableView鼠标双击先触发单击信号
  • 项目进度与预算脱节,如何进行同步管理
  • 从0开始学习R语言--Day47--Nomogram
  • 多租户SaaS系统中设计安全便捷的跨租户流程共享
  • 文心一言开源版部署及多维度测评实例
  • 深度解析 AI 提示词工程(Prompt Engineering)
  • 【YOLOv11-目标检测】06-模型部署(C++)
  • 可微分3D高斯溅射(3DGS)在医学图像三维重建中的应用
  • gRPC实战指南:像国际快递一样调用跨语言服务 —— 解密Protocol Buffer与HTTP/2的完美结合
  • AI 增强大前端数据加密与隐私保护:技术实现与合规遵
  • 20250715武汉xx公司面试一面
  • Springboot儿童认知图文辅助系统6yhkv(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。
  • React.FC与React.Component
  • 高并发四种IO模型的底层原理
  • [Dify]--进阶3-- 如何通过插件扩展 Dify 的功能能力
  • 深入浅出 RabbitMQ-核心概念介绍与容器化部署
  • ubuntu部署kvm