当前位置: 首页 > backend >正文

torch.gather()和torch.sort

torch.gather()

def semantic_neighbor(x, index):
'''
假设x.shape=[B,L,C]=[2,3,4]   index.shape=[B,L]=[2,3]
x = torch.tensor([[[1, 2, 3, 4],    # 样本1的3个元素,每个元素4维特征[5, 6, 7, 8],[9, 10, 11, 12]],[[13, 14, 15, 16], # 样本2的3个元素[17, 18, 19, 20],[21, 22, 23, 24]]
])# 索引张量 index (B=2, L=3)
index = torch.tensor([[1, 0, 1],  # 样本1的重组索引[2, 1, 0]   # 样本2的重组索引
])'''dim = index.dim()#dim=2assert x.shape[:dim] == index.shape, "x ({:}) and index ({:}) shape incompatible".format(x.shape, index.shape)for _ in range(x.dim() - index.dim()):index = index.unsqueeze(-1)'''x.index=[2,3]index = torch.tensor([[[1],[0], [1]], [[2], [1], [0]]  ])'''index = index.expand(x.shape)'''x.index=[2,3,4]index = torch.tensor([[[1,1,1,1],[0,0,0,0], [1,1,1,1]], [[2,2,2,2], [1,1,1,1], [0,0,0,0]]  ])'''shuffled_x = torch.gather(x, dim=dim - 1, index=index)'''tensor([[[ 5,  6,  7,  8],  # 来自原始位置1[ 1,  2,  3,  4],  # 来自原始位置0[ 5,  6,  7,  8]], # 来自原始位置1[[21, 22, 23, 24],  # 来自原始位置2[17, 18, 19, 20],  # 来自原始位置1[13, 14, 15, 16]]  # 来自原始位置0
])'''return shuffled_x'''
另一个简单的示例:
源张量(3x4矩阵)
x = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12]])索引张量(2x3矩阵)
index = torch.tensor([[0, 1, 2],[2, 1, 0]])沿dim=0(行方向)收集
out = torch.gather(x, dim=0, index=index)结果:
[[1,  6, 11],  # 取x[0][0], x[1][1], x[2][2][9,  6,  3]]  # 取x[2][0], x[1][1], x[0][2]]
'''

x.sort()
x_sort_values, x_sort_indices = torch.sort(detached_index, dim=-1, stable=False)

  • torch.sort:对 detached_index 沿 dim=-1(即 n 维度)进行排序。
  • detached_index=[[2,0,1,0]]那么detached_index 排序后的值是 [[0, 0, 1, 2]](即 x_sort_values)。
  • x_sort_indices[[1, 3, 2, 0]],表示:
    • 排序后的第0个元素来自原始位置1(值是0),
    • 第1个元素来自原始位置3(值是0),
    • 第2个元素来自原始位置2(值是1),
    • 第3个元素来自原始位置0(值是2)。
http://www.xdnf.cn/news/8250.html

相关文章:

  • 火语言UI组件--控件函数调用
  • 免费开源的图片分割小工具
  • RT-Thread源码阅读(1)——基本框架
  • 通过云服务器实现异地组网 部署WireGuard
  • 【机器学习】 关于外插修正随机梯度方法的数值实验
  • 听脑AI:革新沟通方式,开启高效信息时代
  • 核实发票的真实性与合法性-发票查验接口-虚假发票防范
  • 关于Newtonsoft版本不兼容问题处理
  • sentinel滑动时间窗口算法详解
  • 系统性能分析基本概念(3) : Tuning Efforts
  • imuerrset
  • PT8P2104触控型8Bit MCU
  • 【Django Serializer】一篇文章详解 Django 序列化器
  • deep-rtsp 摄像头rtsp配置工具
  • 高频与超高频RFID读写器技术应用差异解析
  • 论文解读: 2018-Detection of spam reviews: a sentiment analysis approach
  • 宝尊电商一季度净收入21亿元 品牌管理收入同比大增
  • 冲刺卷软考总结-案例分析
  • 地信GIS专业关于学习、考研、就业方面的一些问题答疑
  • Windows系统下Docker安装青龙面板
  • 常见高危端口解析:网络安全中的“危险入口”
  • 101个α因子#15
  • CentOS7安装 PHP-FPM 7.4
  • 2025海外短剧CPS系统开发指南:高付费市场解析与增速全景图
  • 【CSS】九宫格布局
  • openEuler 22.03 LTS-SP3 系统安装 docker 26.1.3、docker-compose
  • 微信小程序之Promise-Promise初始用
  • 笔记:将一个文件服务器上的文件(一个返回文件数据的url)作为另一个http接口的请求参数
  • 重读《人件》Peopleware -(11)Ⅱ 办公环境 Ⅳ 插曲:生产力测量与不明飞行物
  • Nginx核心功能