PyTorch中diag_embed和transpose函数使用详解
torch.diag_embed
是 PyTorch 中用于将一个向量(或批量向量)**嵌入为对角矩阵(或批量对角矩阵)**的函数。它常用于图神经网络(GNN)或线性代数中生成对角矩阵。
函数原型
torch.diag_embed(input, offset=0, dim1=-2, dim2=-1)
参数解释:
input
:形状为(..., n)
的张量,表示一个或多个长度为n
的向量;offset
:对角偏移量(默认是0
,即主对角线);dim1
,dim2
:在哪两个维度上插入对角矩阵(通常保持默认即可)。
示例
示例 1:单个向量生成对角矩阵
x = torch.tensor([1, 2, 3])
out = torch.diag_embed(x)
# 输出:
# tensor([[1, 0, 0],
# [0, 2, 0],
# [0, 0, 3]])
示例 2:批量嵌入
x = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: (2, 3)
out = torch.diag_embed(x)
# 输出 shape: (2, 3, 3)
# 第一个矩阵是 [1,2,3] 的对角形式,第二个是 [4,5,6] 的对角形式
应用场景(
degree_signal = torch.sum(corr_graph, dim=-1) # shape: (1, N)
D = torch.diag_embed(degree_signal) # shape: (1, N, N)
corr_laplacian = (D - corr_graph).squeeze(0) # shape: (N, N)
这个操作是为了构造图拉普拉斯矩阵(Laplacian):
L = D − A L = D - A L=D−A
其中:
- A A A 是图的邻接矩阵(
corr_graph
); - D D D 是度矩阵(对角矩阵,
diag_embed(degree_signal)
)。
在 PyTorch 中,transpose()
是用于交换张量中两个指定维度的函数,常用于调整张量维度顺序,特别是在矩阵运算或图神经网络等场景中。
函数格式:
torch.transpose(input, dim0, dim1)
# 或者张量对象方法形式:
input.transpose(dim0, dim1)
参数说明:
input
:输入的张量(Tensor)。dim0
:要交换的第一个维度索引。dim1
:要交换的第二个维度索引。
示例 1:二维张量(矩阵)
x = torch.tensor([[1, 2], [3, 4]]) # shape: (2, 2)
print(x.shape) # torch.Size([2, 2])y = x.transpose(0, 1) # 转置矩阵
print(y)
# tensor([[1, 3],
# [2, 4]])
示例 2:三维张量
x = torch.randn(2, 3, 4) # shape: (batch=2, height=3, width=4)# 交换第1维(height)和第2维(width)
y = x.transpose(1, 2) # shape: (2, 4, 3)
print(y.shape)
注意事项:
transpose()
只交换两个维度,如果要重新排列多个维度,请使用permute()
。transpose()
返回的是一个视图(view),不复制数据。
与 .T
的区别:
tensor.T
只适用于 二维张量,是transpose(0, 1)
的简写。- 多维张量请使用
transpose(dim0, dim1)
或permute()
。
示例:配合 .permute()
x = torch.randn(2, 3, 4)
# 等价于 transpose(1, 2)
x.transpose(1, 2) == x.permute(0, 2, 1) # True