当前位置: 首页 > java >正文

PyTorch 中torch.einsum函数的使用详解和工程应用示例

torch.einsum 是 PyTorch 中非常强大的函数,用于 灵活定义张量的乘法、求和、转置、点乘、外积等各种线性代数操作。它源自 Einstein Summation(爱因斯坦求和约定),是一种更紧凑、可读性更强的多维操作方式。


函数原型

torch.einsum(equation, *operands)
  • equation: 字符串,定义操作规则。
  • operands: 一个或多个 Tensor,参与运算的张量。

爱因斯坦求和规则简要

equation 中:

  • 字母表示张量的维度;
  • 重复的字母表示该维度将被求和;
  • 不重复的字母表示结果张量保留的维度。

常见用途与示例

1. 向量内积(dot product)

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([0.1, 0.2, 0.3])
res = torch.einsum('i,i->', a, b)
print(res)  # 输出:1.4

解释:i,i-> 表示对所有 i 求乘积再求和。


2. 矩阵乘法(matrix multiplication)

A = torch.randn(2, 3)
B = torch.randn(3, 4)
C = torch.einsum('ik,kj->ij', A, B)
print(C.shape)  # torch.Size([2, 4])

相当于 torch.matmul(A, B)


3. 矩阵向量乘法

A = torch.randn(3, 4)
x = torch.randn(4)
y = torch.einsum('ij,j->i', A, x)
print(y.shape)  # torch.Size([3])

4. 批量矩阵乘法

A = torch.randn(5, 2, 3)  # batch_size=5
B = torch.randn(5, 3, 4)
C = torch.einsum('bij,bjk->bik', A, B)
print(C.shape)  # torch.Size([5, 2, 4])

5. 转置(矩阵维度调换)

A = torch.randn(3, 4)
B = torch.einsum('ij->ji', A)
print(B.shape)  # torch.Size([4, 3])

6. 外积(outer product)

a = torch.tensor([1.0, 2.0])
b = torch.tensor([10.0, 20.0, 30.0])
res = torch.einsum('i,j->ij', a, b)
print(res)
# 输出形状为 (2, 3)

7. 计算每个 batch 的向量 L2 范数平方

x = torch.randn(32, 128)  # batch_size=32, dim=128
norm_sq = torch.einsum('bi,bi->b', x, x)

8. 计算注意力权重矩阵

q = torch.randn(32, 8, 64)  # query: batch, head, dim
k = torch.randn(32, 8, 64)
attn_scores = torch.einsum('bhd,bhd->bh', q, k)

工程应用示例

图邻接矩阵传播(Graph Adjacency Propagation)通常用于图神经网络(GNN)中,表示将节点特征通过邻接结构进行消息传递(message passing)或特征聚合(feature aggregation)的过程。
我们将使用 torch.einsum 实现这个过程。


1. 图邻接传播的数学表达

设:

  • A ∈ R N × N A \in \mathbb{R}^{N \times N} ARN×N:邻接矩阵(可选归一化)
  • X ∈ R N × F X \in \mathbb{R}^{N \times F} XRN×F:节点特征矩阵(N 个节点,每个有 F 维特征)
  • A X AX AX:每个节点从邻居节点聚合特征

2. PyTorch 示例代码(einsum 实现)

import torch# 假设图有 4 个节点,每个节点有 3 维特征
X = torch.tensor([[1.0, 0.5, 2.0],[0.3, 1.2, 0.7],[0.8, 0.1, 1.1],[0.0, 0.3, 0.4]])  # (4, 3)# 邻接矩阵 A(可为稀疏或归一化矩阵)
A = torch.tensor([[1, 1, 0, 0],[1, 1, 1, 0],[0, 1, 1, 1],[0, 0, 1, 1]], dtype=torch.float32)  # (4, 4)# 使用 torch.einsum 进行图传播 AX
# 'ij,jk->ik':A(i,j) * X(j,k) → 输出 (i,k)
X_agg = torch.einsum('ij,jk->ik', A, X)print("聚合后的特征:")
print(X_agg)

3. 加权传播(带权重)

若有可学习的线性层 W ∈ R F × F ′ W \in \mathbb{R}^{F \times F'} WRF×F,传播过程变为:

A X W AXW AXW

代码示例:

# W 是可学习的线性变换
W = torch.nn.Linear(in_features=3, out_features=2, bias=False)# XW: 节点特征线性变换
X_transformed = W(X)  # (4, 2)# 再传播
X_out = torch.einsum('ij,jk->ik', A, X_transformed)print("传播后的新特征维度:", X_out.shape)

4. 扩展到 Batch 形式(多个图)

设:

  • X: (B, N, F) — batch_size 个图,每图 N 节点 F 维特征
  • A: (B, N, N) — batch 的邻接矩阵
# B 个图,每图 4 节点,每个节点 3 维特征
X = torch.randn(8, 4, 3)       # (B, N, F)
A = torch.eye(4).repeat(8, 1, 1)  # (B, N, N)# 批量邻接传播
X_agg = torch.einsum('bij,bjf->bif', A, X)  # 输出 (B, N, F)

5. 总结

操作einsum 表达说明
图传播(AX)'ij,jk->ik'基础图邻接传播
批量传播'bij,bjf->bif'批次图传播
权重传播AX @ W or 'ij,jk->ik' then Linear加入特征变换

为什么用 einsum

  • 替代嵌套的 permute + view + matmul
  • 更接近数学表达;
  • 性能有时更优;
  • 可用于写清晰的复杂操作,如 self-attention、卷积、图神经网络等。

小技巧

  • einsum_path 可用于优化路径选择:
torch.einsum_path('bij,bjk->bik', A, B, optimize='optimal')

总结

功能示例公式
向量点积'i,i->'
矩阵乘法'ik,kj->ij'
外积'i,j->ij'
转置'ij->ji'
batch matmul'bij,bjk->bik'
L2 norm square'bi,bi->b'
http://www.xdnf.cn/news/13462.html

相关文章:

  • QML显示图片问题解决办法
  • IDEA的git提交代码提交失败,有错误0 个文件已提交,1 个文件提交失败:
  • 双路 CPU 物理服务器租用服务
  • 鹰盾视频加密器Windows播放器禁止虚拟机运行的技术实现解析
  • 青藏高原ASTER_GDEM数据集(2011)
  • Linux C学习路线全概括及知识点笔记3-网络编程
  • AI 视频创作技术全解析:从环境搭建到实战落地​
  • 2025年的WWDC所更新的内容
  • JS 原型与原型链详解
  • mac redis以守护进程重新启动
  • MySQL之事务与视图
  • 【笔记】Kubernetes 中手动及自动化证书更换步骤及注意事项
  • 如何开启自己计算机远程桌面连接功能? 给别人或异地访问
  • 8.Vue的watch监视
  • 从sdp开始到webrtc的通信过程
  • 第二十六课:手搓梯度增强
  • 深入浅出:C++深拷贝与浅拷贝
  • Jadx(开源AVA反编译工具) v1.5.0
  • 编译线程安全的HDF5库
  • Python环境搭建竞赛技术
  • 代码训练LeetCode(29)最后一个单词的长度
  • Github月度新锐热门工具 - 202506
  • PyTorch:让深度学习像搭积木一样简单!!!
  • 邮件限流器
  • 《Redis》持久化
  • 国产linux系统(银河麒麟,统信uos)使用 PageOffice实现word 文档中的table插入新行并赋值
  • 论文略读:RegMix: Data Mixture as Regression for Language Model Pre-training
  • CATIA高效工作指南——常规配置篇(四)
  • deepbayes: VI回顾和GMM近似推断
  • 分布式事务的炼狱:Spring Cloud 微服务架构下的数据一致性保障战