当前位置: 首页 > ds >正文

PyTorch学习之:torch.gather是什么?

torch.gather的定义:

torch.gather 是 PyTorch 中的一个张量操作函数,其作用是根据指定的维度dim)和索引张量index),从输入张量(input)中收集元素,生成一个与索引张量形状相同的输出张量。总体来说,就是维度dim和索引张量index决定一个收集数的规则,然后,基于这个规则从输入张量中获取需要的元素。

核心部分:

1.输入张量input):

  • 任意形状的张量。

2.索引张量index):

  • 形状必须与输入张量在除 dim 外的其他维度上一致。

  • 索引值必须在输入张量 dim 维度的有效范围内(即 0 到 size(dim)-1)。

3.输出张量output):

  • 形状与索引张量相同。

  • 每个元素的值由以下规则确定:

output[i][j][k] = input[i][index[i][j][k]][k]  # 当 dim=1 时

举例详解:

示例 1:二维张量,dim=1

import torchinput = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]], dtype=torch.long)output = torch.gather(input, dim=1, index=index)
print(output)

输出

tensor([[1, 1],[4, 3]])

 解释

输入是一个2x2的矩阵,因为dim是1,所以我们参考下面的公式:

output[i][j] = input[i][index[i][j]]  # 当 dim=1 时

对于输出的第0行第0列(i = 0, j = 0),index对应的位置为0(因为index[0][0]为0),所以,对应的输出等于input[0][0](即为1)。

对于输出的第0行第1列(i = 0, j = 1),index对应的位置为0(因为index[0][1]为0),所以,对应的输出等于input[0][0](即为1)。

对于输出的第1行第0列(i = 1, j = 0),index对应的位置为1(因为index[1][0]为1),所以,对应的输出等于input[1][1](即为4)。

对于输出的第1行第1列(i = 1, j = 1),index对应的位置为0(因为index[1][1]为0),所以,对应的输出等于input[1][0](即为3)。

所以,最后的结果为:

tensor([[1, 1],[4, 3]])

 示例 2:二维张量,dim=0

import torchinput = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]], dtype=torch.long)output = torch.gather(input, dim=0, index=index)
print(output)

输出

tensor([[1, 2],[3, 2]])

 解释

输入是一个2x2的矩阵,因为dim是0,所以我们参考下面的公式:

output[i][j] = input[index[i][j]][j]  # 当 dim=0 时

对于输出的第0行第0列(i = 0, j = 0),index对应的位置为0(因为index[0][0]为0),所以,对应的输出等于input[0][0](即为1)。

对于输出的第0行第1列(i = 0, j = 1),index对应的位置为0(因为index[0][1]为0),所以,对应的输出等于input[0][1](即为2)。

对于输出的第1行第0列(i = 1, j = 0),index对应的位置为1(因为index[1][0]为1),所以,对应的输出等于input[1][0](即为3)。

对于输出的第1行第1列(i = 1, j = 1),index对应的位置为0(因为index[1][1]为0),所以,对应的输出等于input[0][1](即为2)。

所以,最后的结果为:

tensor([[1, 2],[3, 2]])

http://www.xdnf.cn/news/8003.html

相关文章:

  • 海康NVR录像回放SDK原始流转FLV视频流:基于Java的流媒体转码(无需安装第三方插件ffmpeg)
  • 远程访问家里的路由器:异地访问内网设备或指定端口网址
  • 芯片分享之X5045PI性能介绍
  • Backbone
  • Typescript 教程
  • Baklib智启企业AI知识管理
  • MySQL 主从复制搭建全流程:基于 Docker 与 Harbor 仓库
  • 杂记10---ldd获取依赖so名称并导出txt文件
  • 数字电子技术基础(六十二)——使用Multisim软件绘制边沿触发的D触发器和JK触发器
  • 2025年 PMP 6月 8月 专题知识
  • Python数据分析基础
  • LangChain入门和应用#1
  • 工商总局可视化模版-Echarts的纯HTML源码
  • CMake跨平台编译生成:从理论到实战
  • 现代计算机图形学Games101入门笔记(二十一)
  • 【Linux安装与维护】
  • 深入理解C#实例构造函数:对象初始化的关键
  • 动态规划3、悟到核心
  • 【DB2】SQL1639N 处理
  • 建立java项目
  • 免费iOS签名的能使用吗?
  • 【钱包协议】:WalletConnect 详解
  • 一步步解析 HTTPS
  • 网络安全管理之钓鱼演练应急预案
  • PCB设计教程【入门篇】——电路分析基础-元件数据手册
  • Nginx核心服务
  • 【机器学习基础】机器学习与深度学习概述 算法入门指南
  • R语言速查表
  • 什么是瞬态动力学?
  • 从芯片互连到机器人革命:英伟达双线出击,NVLink开放生态+GR00T模型定义AI计算新时代