当前位置: 首页 > ds >正文

从代码学习深度学习 - 转置卷积 PyTorch版

文章目录

  • 前言
  • 基本操作
  • 填充、步幅和多通道
    • 填充 (Padding)
    • 步幅 (Stride)
    • 多通道
  • 总结


前言

在卷积神经网络(CNN)的大家族中,我们熟悉的卷积层和汇聚(池化)层通常会降低输入特征图的空间维度(高度和宽度)。然而,在许多应用场景中,例如图像的语义分割(需要对每个像素进行分类)或生成对抗网络(GAN)中的图像生成,我们反而需要增加特征图的空间维度,即进行上采样。

转置卷积(Transposed Convolution),有时也被不那么准确地称为反卷积(Deconvolution),正是实现这一目标的关键操作。它能够将经过下采样的低分辨率特征图恢复到较高的分辨率,或者在生成模型中从低维噪声逐步生成高分辨率图像。

本文将通过具体的 PyTorch 代码示例,带您一步步理解转置卷积的基本原理、填充(Padding)、步幅(Stride)以及在多通道情况下的应用。

完整代码:下载连接

基本操作

让我们从最基础的转置卷积开始。假设我们有一个 2x2 的输入张量,并使用一个 2x2 的卷积核,步幅为1,没有填充。转置卷积的操作过程可以直观地理解为:将输入张量的每个元素作为标量,与卷积核相乘,得到多个中间结果;然后,将这些中间结果按照输入元素在原张量中的位置进行“放置”和叠加,从而得到最终的输出张量。

其核心思想可以看作是常规卷积操作的一种“逆向”映射,但它并非严格意义上的数学逆运算。

下图形象地展示了这个过程:

在这里插入图片描述

图中,输入是 2x2,卷积核是 2x2。

  1. 输入张量的左上角元素(0)与整个卷积核相乘,结果放置在输出张量的左上角。
  2. 输入张量的右上角元素(1)与整个卷积核相乘,结果向右移动一格放置。
  3. 输入张量的左下角元素(2)与整个卷积核相乘,结果向下移动一格放置。
  4. 输入张量的右下角元素(3)与整个卷积核相乘,结果向右和向下各移动一格放置。
  5. 所有这些放置的张量在重叠区域进行元素相加,得到最终的 3x3 输出。

输出张量的高度 (H_out) 和宽度 (W_out) 可以通过以下公式计算(当步幅为1,无填充时):

  • H_out = H_in + H_kernel - 1
  • W_out = W_in + W_kernel - 1

下面我们用代码来实现这个基本操作:

import torch
from torch import nndef transposed_convolution(input_tensor, kernel):"""实现转置卷积(反卷积)操作参数:input_tensor: 输入张量,维度为 (input_height, input_width)kernel: 卷积核,维度为 (kernel_height, kernel_width)返回:output_tensor: 转置卷积结果,维度为 (input_height + kernel_height - 1, input_width + kernel_width - 1)"""# 获取卷积核的高度和宽度,维度分别为 scalarkernel_height, kernel_width = kernel.shape# 初始化输出张量,维度为 (input_height + kernel_height - 1, input_width + kernel_width - 1)output_tensor = torch.zeros((input_tensor.shape[0] + kernel_height - 1, input_tensor.shape[1] + kernel_width - 1))# 对输入张量中的每个元素进行处理for i in range(input_tensor.shape[0]):  # 遍历输入张量的行for j in range(input_tensor.shape[1]):  # 遍历输入张量的列# 对于输入张量中的每个元素,将其与卷积核相乘,然后加到输出张量的对应区域# input_tensor[i, j] 是标量,维度为 ()# kernel 维度为 (kernel_height, kernel_width)# 输出区域 output_tensor[i:i+kernel_height, j:j+kernel_width] 维度为 (kernel_height, kernel_width)output_tensor[i:i + kernel_height, j:j + kernel_width] += input_tensor[i, j] * kernelreturn output_tensor# 示例使用
# 创建输入张量X,维度为 (2, 2)
X = torch.tensor([[0.0, 1.0], [2.0, 3.0]])# 创建卷积核K,维度为 (2, 2)
K = torch.tensor([
http://www.xdnf.cn/news/5524.html

相关文章:

  • Oracle 通过 ROWID 批量更新表
  • QT6 源(93)篇三:阅读与注释共用体类 QVariant 及其源代码,本类支持比较运算符 ==、!=。
  • Docker Compose 的历史和发展
  • Python实用工具:pdf转doc
  • flutter 项目工程文件夹组织结构
  • 新手在使用宝塔Linux部署前后端分离项目时可能会出现的问题以及解决方案
  • Linux-TCP套接字编程简易实践:实现EchoServer与远程命令执行及自定义协议(反)序列化
  • 【JavaWeb+后端常用部件】
  • Day 5:Warp高级定制与自动化
  • 足式机器人的全身模型预测控制
  • 常用设计模式
  • 一种混沌驱动的后门攻击检测指标
  • GC垃圾回收
  • vector的大小
  • Java开发经验——阿里巴巴编码规范经验总结2
  • (2025)图文解锁RAG从原理到代码实操,代码保证可运行
  • 自学嵌入式 day 17- c语言-第11章 结构体与共用体 第12章 位运算
  • 深入浅出之STL源码分析5_类模版实例化与特化
  • RAG与语义搜索:让大模型成为测试工程师的智能助手
  • DVWA靶场Cryptography模块medium不看原码做法
  • Python时间模块
  • MySQL 从入门到精通(二):DML 数据操作与 DQL 数据查询详解
  • Python项目75:PyInstaller+Tkinter+subprocess打包工具1.0(安排 !!)
  • 阿里云OSS-服务端加签直传说明/示例(SpringBoot)
  • Python数据分析案例75——基于图神经网络的交通路段流量时间序列预测
  • navicat 如何导出数据库表 的这些信息 字段名 类型 描述
  • fota移植包合入后编译验证提示:File verification failed
  • Java线程池深度解析:从使用到原理全面掌握
  • KTOR for windows:無文件落地HTTP服务扫描工具
  • 【Bootstrap V4系列】学习入门教程之 组件-表单(Forms)高级用法(二)