当前位置: 首页 > news >正文

PyTorch中flatten()函数详解以及与view()和 reshape()的对比和实战代码示例

在 PyTorch 中,flatten() 函数常用于将张量(tensor)展平成一维或多维结构,尤其在构建神经网络(如 CNN)时,从卷积层输出进入全连接层前经常使用它。


一、基本语法

torch.flatten(input, start_dim=0, end_dim=-1)

参数说明:

参数说明
input输入张量
start_dim开始展平的维度(包含该维)
end_dim结束展平的维度(包含该维)

展平操作会把 start_dimend_dim 之间的维度合并成一维。


二、常见示例

示例 1:基本使用

import torchx = torch.tensor([[[1, 2],[3, 4]],[[5, 6],[7, 8]]])  # shape = (2, 2, 2)out = torch.flatten(x)
print(out)
print(out.shape)  # torch.Size([8])

等价于 x.view(-1),即将所有维度展平成一维。


示例 2:保留前维度(常见于 CNN)

x = torch.randn(10, 3, 32, 32)  # 10张图片,3通道,32x32大小
out = torch.flatten(x, start_dim=1)print(out.shape)  # torch.Size([10, 3072])

解释:

  • 展平从第 1 维开始(channel, height, width)→ 展平成一个维度
  • 第 0 维(batch size)保留,适合连接到 nn.Linear

示例 3:多维展开(指定 end_dim)

x = torch.randn(2, 3, 4, 5)  # shape = (2, 3, 4, 5)
out = torch.flatten(x, start_dim=1, end_dim=2)print(out.shape)  # torch.Size([2, 12, 5]) -> (3*4 = 12)

三、与 .view() 的区别

函数说明
view()更底层、需要张量是连续的,手动指定形状
flatten()更高层、更安全、自动处理维度合并,常用于模型构建中

四、常见用法:在模型中使用

1、示例1

import torch.nn as nnclass MyCNN(nn.Module):def __init__(self):super().__init__()self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.pool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(16, 10)def forward(self, x):x = self.conv(x)x = self.pool(x)              # shape: (N, 16, 1, 1)x = torch.flatten(x, 1)       # shape: (N, 16)x = self.fc(x)return x

2、示例2

下面使用了 torch.flatten() 将卷积层的输出展平,并连接到全连接层。这个结构常见于 CNN 图像分类模型。


使用 flatten() 的 CNN 训练流程(以 CIFAR-10 为例)

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# ==== 1. 定义 CNN 模型,使用 flatten() ====
class FlattenCNN(nn.Module):def __init__(self):super(FlattenCNN, self).__init__()self.conv = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),  # 输入: [B, 3, 32, 32]nn.ReLU(),nn.MaxPool2d(2),                # 输出: [B, 16, 16, 16]nn.Conv2d(16, 32, 3, padding=1),nn.ReLU(),nn.MaxPool2d(2)                 # 输出: [B, 32, 8, 8])self.fc = nn.Sequential(nn.Linear(32 * 8 * 8, 128),nn.ReLU(),nn.Linear(128, 10)              # CIFAR-10 共 10 类)def forward(self, x):x = self.conv(x)x = torch.flatten(x, 1)  # 👈 仅展平通道和空间维度,保留 batchx = self.fc(x)return x# ==== 2. 准备数据 ====
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# ==== 3. 模型训练设置 ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FlattenCNN().to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)# ==== 4. 训练过程 ====
def train(model, loader, epochs):model.train()for epoch in range(epochs):total_loss = 0.0for images, labels in loader:images, labels = images.to(device), labels.to(device)outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()avg_loss = total_loss / len(loader)print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f}")# ==== 5. 开始训练 ====
train(model, train_loader, epochs=5)

重点说明

使用 torch.flatten(x, 1) 的原因:

  • 只展平通道、高、宽三维(保留 batch size)
  • 替代 x.view(x.size(0), -1) 更安全,避免非连续张量报错
  • 推荐在模型中构建更加模块化、清晰

五、三种张量展平方式:flatten()view()reshape() 的对比

下面从功能差异使用限制和**性能对比(benchmark)**进行三者的比较。


1、三者功能对比

函数特点说明
flatten()高级 API,自动处理维度合并,不要求张量连续。推荐模型中使用。
view()底层操作,速度快,但要求张量是连续(tensor.is_contiguous()True
reshape()更灵活,如果张量不连续,会自动复制为连续版本。性能略慢但更安全

2、代码功能对比

x = torch.randn(32, 3, 64, 64)  # batch of images# flatten
f1 = torch.flatten(x, 1)# view
f2 = x.view(32, -1)# reshape
f3 = x.reshape(32, -1)print(f1.shape, f2.shape, f3.shape)

输出一致:torch.Size([32, 12288])


3、非连续张量对比(view 会报错)

x = torch.randn(2, 3, 4)
y = x.permute(0, 2, 1)  # 非连续张量try:y.view(-1)  # 会报错
except RuntimeError as e:print("view error:", e)print("reshape:", y.reshape(-1).shape)   # reshape 正常
print("flatten:", torch.flatten(y).shape)  # flatten 正常

4、性能测试(benchmark)

import torch
import timex = torch.randn(1024, 512, 28, 28)# 保证是连续的
x_contig = x.contiguous()N = 1000def benchmark(op, name):torch.cuda.synchronize()start = time.time()for _ in range(N):_ = op(x_contig)torch.cuda.synchronize()end = time.time()print(f"{name}: {(end - start)*1000:.2f} ms")benchmark(lambda x: torch.flatten(x, 1), "flatten()")
benchmark(lambda x: x.view(x.size(0), -1), "view()")
benchmark(lambda x: x.reshape(x.size(0), -1), "reshape()")

示例结果(A100 GPU):

flatten(): 58.12 ms
view():    41.76 ms
reshape(): 47.32 ms

总结view()最快,但要求张量连续;flatten()最安全但稍慢;reshape()是折中方案。


5、 建议总结

场景推荐方式原因
模型中展平 CNN 输出flatten()简洁、安全,尤其在复杂网络中
确保连续张量、追求速度view()性能最佳
张量可能非连续reshape()自动处理不连续情况,代码更鲁棒

六、小结

用法效果
torch.flatten(x)将所有维展平成一维
torch.flatten(x, 1)保留 batch 维,常用于 CNN
torch.flatten(x, 1, 2)展平指定维度区间

http://www.xdnf.cn/news/1203733.html

相关文章:

  • dapp前端⾯试题
  • 【QT搭建opencv环境】
  • <RT1176系列11>DMAMUX解读
  • Spring AI 1.0 提供简单的 AI 系统和服务
  • TS面试题
  • 分布式IO详解:2025年分布式无线远程IO采集控制方案选型指南
  • simple-mock-proxy,自动拾取后端接口数据,生成本地mock接口与数据
  • idea启动java应用报错
  • keepalived原理及实战部署
  • vue怎么实现导入excel表功能
  • 最新!Polkadot 更新 2025 路线图
  • C++-关于协程的一些思考
  • ERC20 和 XCM Precompile|详解背后技术逻辑
  • 【Kotlin】如何实现静态方法?(单例类、伴生对象、@JvmStatic)
  • Android中应用进程中Binder创建机制
  • VUE2 学习笔记11 脚手架
  • 从0到500账号管理:亚矩阵云手机多开组队与虚拟定位实战指南
  • 数据结构之顺序表链表栈
  • 分享一个脚本,从mysql导出数据csv到hdfs临时目录
  • CFIHL: 水培生菜的多种叶绿素 a 荧光瞬态图像数据集
  • 雷达系统设计学习:自制6GHz FMCW Radar
  • 深入解析 Spring 获取 XML 验证模式的过程
  • 可以组成网络的服务器 - 华为OD统一考试(JavaScript 题解)
  • 速度革命 Kingston FURY PCIe 5.0 NVMe装机体验
  • 第四章:分析 Redis 性能高原因和核心字符串类型命令
  • 15-C语言:第15天笔记
  • Nginx 四层(stream)反向代理 + DNS 负载均衡
  • Java面试深度剖析:从JVM到云原生的技术演进
  • JVM 内存共享区域详解
  • 解决cordova编译安卓提示Cloud not find XXXX.aar