当前位置: 首页 > news >正文

【Einops】Einops rearrange方法详解

Einops

Einops 是一个用于张量操作的库,它使得张量的重塑(reshaping)、转置(transposing)和其他操作更加简洁和易于理解。通过使用 einops,你可以用一种更直观的方式处理多维数组,比如在机器学习、深度学习以及其他需要操作高维数据的领域中常见的张量。

rearrange

用来重新排列张量(比如 NumPy 数组或 PyTorch 张量)的维度

rearrange(x, pattern, **axes_lengths)
  • x: 输入的张量(可以是任何支持的操作数,如 NumPy 数组、PyTorch 张量等)。
  • pattern: 一个字符串,定义了输入到输出的转换规则。它由空格分隔的不同维度名称组成,并用箭头->分成两部分:左边是输入张量的维度描述,右边是输出张量的维度描述。
  • **axes_lengths: 可选参数,用于指定某些维度的具体大小,特别是当这些维度在模式中被聚合或分割时。

rearrange - pattern语法

rearrange(x, 'input_pattern -> output_pattern')
  • x:是你要操作的张量。
  • 'input_pattern -> output_pattern':就是一个字符串,告诉函数:
    • 输入张量是怎么样的(左边)
    • 输出张量应该长什么样(右边)

pattern的维度名称

在 pattern 里,每个维度用一个名字表示,比如:

  • b 表示 batch size(批次大小)
  • c 表示 channels(通道数)
  • h 表示 height(高度)
  • w 表示 width(宽度)
  • t 表示 time(时间步)
  • 等等...

这些名字是你自己定义的,可以随便起,只要前后一致就行。

例子

以下面这段代码为例进行讲解:

patch_size = 16
patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)

箭头左边

'b c h w'

这表示输入张量有 4 个维度,顺序是:batch, channel, height, width。

'b c (h s1) (w s2)'
  • (h s1) 表示:原来的高度维度其实是由两个部分组成的:hs1
  • (w s2) 同理:原来的宽度由 ws2 构成

换句话说,这是一个拆分操作

  • 原来的高度 = h * s1
  • 原来的宽度 = w * s2

所以这段 pattern 实际上是在说:

我知道输入张量的高度和宽度其实是被划分成了 h 个小块,每个小块大小是 s1;同理宽度方向是 w 个小块,每个是 s2。

箭头右边

'b (h w) (s1 s2 c)'

这部分描述的是输出张量的形状

  • b:保留 batch 维度不变
  • (h w):把 h 和 w 这两个维度合并成一个新维度,它的长度是 h * w
  • (s1 s2 c):把 s1、s2 和 c 三个维度合并成一个新维度,长度是 s1 * s2 * c

所以输出张量的维度:

(batch_size, h*w, s1*s2*c)

综合解释

✅ 假设 x 的 shape 是:(batch_size, channels, height, width)

假设 height = 128, width = 128, channels = 3, patch_size = 16

那我们可以算出:

  • h = height / patch_size = 128 / 16 = 8
  • w = width / patch_size = 128 / 16 = 8

所以 pattern 中的变量值是:

  • h = 8, w = 8, s1 = 16, s2 = 16, c = 3

代入输出模式:

  • b = batch_size
  • (h w) = 8 * 8 = 64(总共有 64 个 patch)
  • (s1 s2 c) = 16 * 16 * 3 = 768(每个 patch 被展平为 768 维向量)

最终输出的 shape 就是:

(batch_size, 64, 768)

更多操作

示例 1:转置(交换维度)

x = np.random.randn(2, 3, 4)  # shape: (a, b, c)
y = rearrange(x, 'a b c -> a c b')  # shape: (2, 4, 3)

说明:把第2维和第3维交换了位置。


示例 2:合并维度

x = np.random.randn(2, 3, 4, 5)  # shape: (a, b, c, d)
y = rearrange(x, 'a b c d -> a b (c d)')  # shape: (2, 3, 20)

说明:把 c 和 d 合并成一个维度,变成 c*d=4*5=20


示例 3:拆分维度

x = np.random.randn(2, 3, 20)  # shape: (a, b, c)
y = rearrange(x, 'a b (c d) -> a b c d', d=5)  # shape: (2, 3, 4, 5)

说明:把最后一个维度拆分成两个维度,其中一个是给定的 d=5,另一个自动推导为 c=4

pattern语法规则总结

符号含义
->分隔输入模式和输出模式
空格分隔不同的维度名
()合并/拆分维度
变量自定义的维度名称(如 b, c, h, w
数字固定尺寸(不建议直接写死,推荐用变量)
...表示任意多个未命名的中间维度

总得来说,pattern 就是一套“输入如何拆解 → 输出如何重组”的规则,用一种类自然语言的方式写出来,让张量操作变得直观又强大。

http://www.xdnf.cn/news/503911.html

相关文章:

  • C# 创建线程的方式
  • 一字典两世界:优雅移除 `NSDictionary` 指定键的最佳实践
  • 编程基础:什么是变量
  • 《 C语言中const修饰指针变量的用法与解析》
  • 解决米勒补偿右边零点的方法
  • 【蓝桥杯省赛真题51】python石头运输 第十五届蓝桥杯青少组Python编程省赛真题解析
  • mcp学习笔记
  • day 28
  • ECS/GEM是半导体制造业的标准通信协议中host和equipment的区别是什么,在交互过程中,如何来定位角色谁为host,谁为equipment
  • Spring Boot 中 MyBatis 与 Spring Data JPA 的对比介绍
  • 【Python 算法零基础 3.递推】
  • 【C语言】链接与编译(编译环境 )
  • 配置ssh服务-ubuntu到Windows拷贝文件方法
  • Java Records:简洁的数据建模新方式
  • ubuntu 24.04安装ros1 noetic
  • 历史数据分析——中证白酒
  • 数据库3——视图及安全性
  • 计算机网络体系结构深度解析:从理论到实践的全面梳理
  • 电动调节 V 型球阀:工业流体控制的全能解决方案-耀圣
  • 高考AI试题查询系统
  • 网络切片:给用户体验做“私人定制”的秘密武器
  • 80. Java 枚举类 - 使用枚举实现单例模式
  • 自制操作系统(三、文件系统实现)
  • 8天Python从入门到精通【itheima】-14~16
  • 【PhysUnits】4.2 Integer Trait
  • c/c++的opencv的轮廓匹配初识
  • 提升Qt应用性能--全面解析关键技术与策略
  • C++性能测试工具——Vtune的使用
  • BC 范式与 4NF
  • 全局异常处理:如何优雅地统一管理业务异常