当前位置: 首页 > news >正文

PyTorch中diag_embed和transpose函数使用详解

torch.diag_embed 是 PyTorch 中用于将一个向量(或批量向量)**嵌入为对角矩阵(或批量对角矩阵)**的函数。它常用于图神经网络(GNN)或线性代数中生成对角矩阵。


函数原型

torch.diag_embed(input, offset=0, dim1=-2, dim2=-1)
参数解释:
  • input:形状为 (..., n) 的张量,表示一个或多个长度为 n 的向量;
  • offset:对角偏移量(默认是 0,即主对角线);
  • dim1, dim2:在哪两个维度上插入对角矩阵(通常保持默认即可)。

示例

示例 1:单个向量生成对角矩阵
x = torch.tensor([1, 2, 3])
out = torch.diag_embed(x)
# 输出:
# tensor([[1, 0, 0],
#         [0, 2, 0],
#         [0, 0, 3]])
示例 2:批量嵌入
x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: (2, 3)
out = torch.diag_embed(x)
# 输出 shape: (2, 3, 3)
# 第一个矩阵是 [1,2,3] 的对角形式,第二个是 [4,5,6] 的对角形式

应用场景(

degree_signal = torch.sum(corr_graph, dim=-1)           # shape: (1, N)
D = torch.diag_embed(degree_signal)                     # shape: (1, N, N)
corr_laplacian = (D - corr_graph).squeeze(0)            # shape: (N, N)

这个操作是为了构造图拉普拉斯矩阵(Laplacian):

L = D − A L = D - A L=DA

其中:

  • A A A 是图的邻接矩阵(corr_graph);
  • D D D 是度矩阵(对角矩阵,diag_embed(degree_signal))。

在 PyTorch 中,transpose() 是用于交换张量中两个指定维度的函数,常用于调整张量维度顺序,特别是在矩阵运算或图神经网络等场景中。


函数格式:

torch.transpose(input, dim0, dim1)
# 或者张量对象方法形式:
input.transpose(dim0, dim1)

参数说明:

  • input:输入的张量(Tensor)。
  • dim0:要交换的第一个维度索引。
  • dim1:要交换的第二个维度索引。

示例 1:二维张量(矩阵)

x = torch.tensor([[1, 2], [3, 4]])  # shape: (2, 2)
print(x.shape)  # torch.Size([2, 2])y = x.transpose(0, 1)  # 转置矩阵
print(y)
# tensor([[1, 3],
#         [2, 4]])

示例 2:三维张量

x = torch.randn(2, 3, 4)  # shape: (batch=2, height=3, width=4)# 交换第1维(height)和第2维(width)
y = x.transpose(1, 2)  # shape: (2, 4, 3)
print(y.shape)

注意事项:

  • transpose()交换两个维度,如果要重新排列多个维度,请使用 permute()
  • transpose() 返回的是一个视图(view),不复制数据。

.T 的区别:

  • tensor.T 只适用于 二维张量,是 transpose(0, 1) 的简写。
  • 多维张量请使用 transpose(dim0, dim1)permute()

示例:配合 .permute()

x = torch.randn(2, 3, 4)
# 等价于 transpose(1, 2)
x.transpose(1, 2) == x.permute(0, 2, 1)  # True

http://www.xdnf.cn/news/588331.html

相关文章:

  • 小白的进阶之路系列之三----人工智能从初步到精通pytorch计算机视觉详解上
  • vue2使用pdfmake
  • Qt无边框界面添加鼠标事件
  • 吃透 Golang 基础:数据结构之切片
  • 实现了TCP的单向通信
  • 【数据库】-2 mysql基础语句(上)
  • 旋转编码器计次 红外对射传感器计次小实验及其相关库函数详解 (江协科技)
  • 第四章:YOLOv11 实战应用与开发指南
  • LeetCode 404.左叶子之和的迭代求解:栈结构与父节点定位的深度解析
  • 力扣.H指数力扣.字母异位词力扣.289生命游戏力扣452.用最小数量的箭引爆气球力扣.86分隔链表力扣.轮转数组
  • 高等数学-常微分方程
  • 国产三维CAD皇冠CAD(CrownCAD)建模教程:交流发电机
  • 推荐一个Excel与实体映射导入导出的C#开源库
  • 手写简单的tomcat
  • (泛函分析)线性算子连续必有界的证明
  • GraphRAG使用
  • 动态规划(七)——子数组系列(求和问题)
  • labview实现将百分制分数转换为等级制分数
  • Vue 3 官方 Hooks 的用法与实现原理
  • ai外呼平台:AnKo打造高效多模型服务体验!
  • labview实现LED流水灯的第二种方法
  • 每日算法刷题计划day13 5.22:leetcode不定长滑动窗口最短/最小1道题+求子数组个数越长越合法2道题,用时1h
  • 学习vue3:跨组件通信(provide+inject)
  • vscode include总是报错
  • Ubuntu24.04 LTS安装java8、mysql8.0
  • 【VScode】python初学者的有力工具
  • Labview使用报表工具
  • linux二进制安装mysql:
  • 遥控器处理器与光纤通信技术解析
  • 深入理解指针part1