PyTorch中flatten()函数详解以及与view()和 reshape()的对比和实战代码示例
在 PyTorch 中,flatten()
函数常用于将张量(tensor)展平成一维或多维结构,尤其在构建神经网络(如 CNN)时,从卷积层输出进入全连接层前经常使用它。
一、基本语法
torch.flatten(input, start_dim=0, end_dim=-1)
参数说明:
参数 | 说明 |
---|---|
input | 输入张量 |
start_dim | 开始展平的维度(包含该维) |
end_dim | 结束展平的维度(包含该维) |
展平操作会把
start_dim
到end_dim
之间的维度合并成一维。
二、常见示例
示例 1:基本使用
import torchx = torch.tensor([[[1, 2],[3, 4]],[[5, 6],[7, 8]]]) # shape = (2, 2, 2)out = torch.flatten(x)
print(out)
print(out.shape) # torch.Size([8])
等价于 x.view(-1)
,即将所有维度展平成一维。
示例 2:保留前维度(常见于 CNN)
x = torch.randn(10, 3, 32, 32) # 10张图片,3通道,32x32大小
out = torch.flatten(x, start_dim=1)print(out.shape) # torch.Size([10, 3072])
解释:
- 展平从第 1 维开始(channel, height, width)→ 展平成一个维度
- 第 0 维(batch size)保留,适合连接到
nn.Linear
层
示例 3:多维展开(指定 end_dim)
x = torch.randn(2, 3, 4, 5) # shape = (2, 3, 4, 5)
out = torch.flatten(x, start_dim=1, end_dim=2)print(out.shape) # torch.Size([2, 12, 5]) -> (3*4 = 12)
三、与 .view()
的区别
函数 | 说明 |
---|---|
view() | 更底层、需要张量是连续的,手动指定形状 |
flatten() | 更高层、更安全、自动处理维度合并,常用于模型构建中 |
四、常见用法:在模型中使用
1、示例1
import torch.nn as nnclass MyCNN(nn.Module):def __init__(self):super().__init__()self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.pool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(16, 10)def forward(self, x):x = self.conv(x)x = self.pool(x) # shape: (N, 16, 1, 1)x = torch.flatten(x, 1) # shape: (N, 16)x = self.fc(x)return x
2、示例2
下面使用了 torch.flatten()
将卷积层的输出展平,并连接到全连接层。这个结构常见于 CNN 图像分类模型。
使用 flatten()
的 CNN 训练流程(以 CIFAR-10 为例):
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# ==== 1. 定义 CNN 模型,使用 flatten() ====
class FlattenCNN(nn.Module):def __init__(self):super(FlattenCNN, self).__init__()self.conv = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), # 输入: [B, 3, 32, 32]nn.ReLU(),nn.MaxPool2d(2), # 输出: [B, 16, 16, 16]nn.Conv2d(16, 32, 3, padding=1),nn.ReLU(),nn.MaxPool2d(2) # 输出: [B, 32, 8, 8])self.fc = nn.Sequential(nn.Linear(32 * 8 * 8, 128),nn.ReLU(),nn.Linear(128, 10) # CIFAR-10 共 10 类)def forward(self, x):x = self.conv(x)x = torch.flatten(x, 1) # 👈 仅展平通道和空间维度,保留 batchx = self.fc(x)return x# ==== 2. 准备数据 ====
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# ==== 3. 模型训练设置 ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FlattenCNN().to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)# ==== 4. 训练过程 ====
def train(model, loader, epochs):model.train()for epoch in range(epochs):total_loss = 0.0for images, labels in loader:images, labels = images.to(device), labels.to(device)outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()avg_loss = total_loss / len(loader)print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f}")# ==== 5. 开始训练 ====
train(model, train_loader, epochs=5)
重点说明
使用 torch.flatten(x, 1)
的原因:
- 只展平通道、高、宽三维(保留 batch size)
- 替代
x.view(x.size(0), -1)
更安全,避免非连续张量报错 - 推荐在模型中构建更加模块化、清晰
五、三种张量展平方式:flatten()
、view()
和 reshape()
的对比
下面从功能差异、使用限制和**性能对比(benchmark)**进行三者的比较。
1、三者功能对比
函数 | 特点说明 |
---|---|
flatten() | 高级 API,自动处理维度合并,不要求张量连续。推荐模型中使用。 |
view() | 底层操作,速度快,但要求张量是连续(tensor.is_contiguous() 为 True ) |
reshape() | 更灵活,如果张量不连续,会自动复制为连续版本。性能略慢但更安全 |
2、代码功能对比
x = torch.randn(32, 3, 64, 64) # batch of images# flatten
f1 = torch.flatten(x, 1)# view
f2 = x.view(32, -1)# reshape
f3 = x.reshape(32, -1)print(f1.shape, f2.shape, f3.shape)
输出一致:torch.Size([32, 12288])
3、非连续张量对比(view 会报错)
x = torch.randn(2, 3, 4)
y = x.permute(0, 2, 1) # 非连续张量try:y.view(-1) # 会报错
except RuntimeError as e:print("view error:", e)print("reshape:", y.reshape(-1).shape) # reshape 正常
print("flatten:", torch.flatten(y).shape) # flatten 正常
4、性能测试(benchmark)
import torch
import timex = torch.randn(1024, 512, 28, 28)# 保证是连续的
x_contig = x.contiguous()N = 1000def benchmark(op, name):torch.cuda.synchronize()start = time.time()for _ in range(N):_ = op(x_contig)torch.cuda.synchronize()end = time.time()print(f"{name}: {(end - start)*1000:.2f} ms")benchmark(lambda x: torch.flatten(x, 1), "flatten()")
benchmark(lambda x: x.view(x.size(0), -1), "view()")
benchmark(lambda x: x.reshape(x.size(0), -1), "reshape()")
示例结果(A100 GPU):
flatten(): 58.12 ms
view(): 41.76 ms
reshape(): 47.32 ms
总结:
view()
最快,但要求张量连续;flatten()
最安全但稍慢;reshape()
是折中方案。
5、 建议总结
场景 | 推荐方式 | 原因 |
---|---|---|
模型中展平 CNN 输出 | flatten() | 简洁、安全,尤其在复杂网络中 |
确保连续张量、追求速度 | view() | 性能最佳 |
张量可能非连续 | reshape() | 自动处理不连续情况,代码更鲁棒 |
六、小结
用法 | 效果 |
---|---|
torch.flatten(x) | 将所有维展平成一维 |
torch.flatten(x, 1) | 保留 batch 维,常用于 CNN |
torch.flatten(x, 1, 2) | 展平指定维度区间 |