【Einops】Einops rearrange方法详解
Einops
Einops 是一个用于张量操作的库,它使得张量的重塑(reshaping)、转置(transposing)和其他操作更加简洁和易于理解。通过使用 einops,你可以用一种更直观的方式处理多维数组,比如在机器学习、深度学习以及其他需要操作高维数据的领域中常见的张量。
rearrange
用来重新排列张量(比如 NumPy 数组或 PyTorch 张量)的维度
rearrange(x, pattern, **axes_lengths)
x
: 输入的张量(可以是任何支持的操作数,如 NumPy 数组、PyTorch 张量等)。pattern
: 一个字符串,定义了输入到输出的转换规则。它由空格分隔的不同维度名称组成,并用箭头->
分成两部分:左边是输入张量的维度描述,右边是输出张量的维度描述。**axes_lengths
: 可选参数,用于指定某些维度的具体大小,特别是当这些维度在模式中被聚合或分割时。
rearrange - pattern语法
rearrange(x, 'input_pattern -> output_pattern')
x
:是你要操作的张量。'input_pattern -> output_pattern'
:就是一个字符串,告诉函数:- 输入张量是怎么样的(左边)
- 输出张量应该长什么样(右边)
pattern的维度名称
在 pattern 里,每个维度用一个名字表示,比如:
b
表示 batch size(批次大小)c
表示 channels(通道数)h
表示 height(高度)w
表示 width(宽度)t
表示 time(时间步)- 等等...
这些名字是你自己定义的,可以随便起,只要前后一致就行。
例子
以下面这段代码为例进行讲解:
patch_size = 16
patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)
箭头左边
'b c h w'
这表示输入张量有 4 个维度,顺序是:batch, channel, height, width。
'b c (h s1) (w s2)'
(h s1)
表示:原来的高度维度其实是由两个部分组成的:h
和s1
(w s2)
同理:原来的宽度由w
和s2
构成
换句话说,这是一个拆分操作:
- 原来的高度 =
h * s1
- 原来的宽度 =
w * s2
所以这段 pattern 实际上是在说:
我知道输入张量的高度和宽度其实是被划分成了 h 个小块,每个小块大小是 s1;同理宽度方向是 w 个小块,每个是 s2。
箭头右边
'b (h w) (s1 s2 c)'
这部分描述的是输出张量的形状
b
:保留 batch 维度不变(h w)
:把 h 和 w 这两个维度合并成一个新维度,它的长度是h * w
(s1 s2 c)
:把 s1、s2 和 c 三个维度合并成一个新维度,长度是s1 * s2 * c
所以输出张量的维度:
(batch_size, h*w, s1*s2*c)
综合解释
✅ 假设 x 的 shape 是:(batch_size, channels, height, width)
假设 height = 128
, width = 128
, channels = 3
, patch_size = 16
那我们可以算出:
h = height / patch_size = 128 / 16 = 8
w = width / patch_size = 128 / 16 = 8
所以 pattern 中的变量值是:
h = 8
,w = 8
,s1 = 16
,s2 = 16
,c = 3
代入输出模式:
b = batch_size
(h w) = 8 * 8 = 64
(总共有 64 个 patch)(s1 s2 c) = 16 * 16 * 3 = 768
(每个 patch 被展平为 768 维向量)
最终输出的 shape 就是:
(batch_size, 64, 768)
更多操作
示例 1:转置(交换维度)
x = np.random.randn(2, 3, 4) # shape: (a, b, c)
y = rearrange(x, 'a b c -> a c b') # shape: (2, 4, 3)
说明:把第2维和第3维交换了位置。
示例 2:合并维度
x = np.random.randn(2, 3, 4, 5) # shape: (a, b, c, d)
y = rearrange(x, 'a b c d -> a b (c d)') # shape: (2, 3, 20)
说明:把 c 和 d 合并成一个维度,变成 c*d=4*5=20
示例 3:拆分维度
x = np.random.randn(2, 3, 20) # shape: (a, b, c)
y = rearrange(x, 'a b (c d) -> a b c d', d=5) # shape: (2, 3, 4, 5)
说明:把最后一个维度拆分成两个维度,其中一个是给定的 d=5
,另一个自动推导为 c=4
pattern语法规则总结
符号 | 含义 |
---|---|
-> | 分隔输入模式和输出模式 |
空格 | 分隔不同的维度名 |
() | 合并/拆分维度 |
变量 | 自定义的维度名称(如 b , c , h , w ) |
数字 | 固定尺寸(不建议直接写死,推荐用变量) |
... | 表示任意多个未命名的中间维度 |
总得来说,pattern
就是一套“输入如何拆解 → 输出如何重组”的规则,用一种类自然语言的方式写出来,让张量操作变得直观又强大。