当前位置: 首页 > ops >正文

CNN - 卷积层

2.卷积层
1. 卷积基本原理

卷积是通过卷积核在输入图像上滑动,对每个局部区域进行点对点乘积求和,生成特征图的过程。其核心是提取图像的局部特征(如边缘、纹理等),每个卷积核可视为一种 “特征探测器”。

2. 关键参数及影响
  • 卷积核(Kernel/Filter):大小通常为奇数(1×1、3×3、5×5 等),决定局部感知野范围。
  • 步长(Stride):卷积核每次滑动的像素数。
    • 影响:步长增大,特征图变小,计算量减少,但可能丢失信息;步长过小则重复计算多,效率低。
  • 填充(Padding):在输入图像边缘添加 0 值像素。
    • 作用:保持特征图大小(通过适当填充可使输入输出尺寸一致),保护边缘特征不被过度压缩。
3. 特征图大小计算公式

给定输入大小W、卷积核大小F、步长S、填充P,输出特征图大小N为:\(N = \frac{W - F + 2P}{S} + 1\) 示例:输入 5×5,卷积核 3×3,步长 1,填充 1 → 输出\((5-3+2×1)/1 +1 = 5\),即保持 5×5 大小。

4. 多通道卷积(核心结论)
  • 输入图像含多通道(如 RGB 三通道)时,卷积核需与输入通道数相同(如 3 通道输入对应 3 通道卷积核)。
  • 计算过程:每个通道的卷积核与对应输入通道卷积,结果逐点相加,生成 1 个特征图。
  • 多卷积核:若使用K个卷积核,会生成K个特征图(每个卷积核提取不同特征)。
  • 核心结论:
    • 输入通道数 = 卷积核通道数;
    • 卷积核个数 = 输出特征图通道数。
5. 卷积的优势
  • 参数共享:单个卷积核在图像上滑动时参数复用,大幅减少参数量(远少于全连接层)。例如:32×32×3 的图像用 10 个 5×5 卷积核,参数量仅为\(10×(5×5×3 + 1) = 760\)。
  • 局部特征提取:天然适配图像的空间局部相关性,擅长捕捉局部结构(如边缘、角落)。
6. PyTorch 卷积操作实践
  • 输入格式要求:卷积层输入需为NCHW格式(批次大小、通道数、高度、宽度),需通过permute(转换通道维度)和unsqueeze(添加批次维度)处理原始图像(通常为HWC格式)。
  • 核心 APInn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),其中:
    • in_channels:输入通道数(如 RGB 图为 3,灰度图为 1);
    • out_channels:卷积核个数(即输出特征图通道数)。
    • kernel_size:卷积核
    • stride:步长
    • padding:填充
7. 实验验证

        通过虚拟仿真工具可直观展示多通道卷积过程:输入多通道图像→多通道卷积核分别卷积→结果求和 + 偏置→生成输出特征图。实验验证了 “输入通道数决定卷积核通道数,卷积核个数决定输出通道数” 的核心逻辑,且可生成对应代码(如 PyTorch 实现),实现理论与实践结合。

示例1:

import os.pathimport torch
from matplotlib import pyplot as plt
from torch import nn# 构建图像相对路径
# os.path.dirname(__file__):当前脚本所在目录
image_path = os.path.realpath(os.path.join(os.path.dirname(__file__), '彩色.png'))
img_data = plt.imread(image_path)
print(img_data.shape)# 转换图像形状以适应Pytorch卷积输入层要求
# 原来是hwc
# 1.transpose(2, 0, 1):将通道维度c移到第一维,现状变为(C,H,W)
# 2. unsqueeze(0):再最前面增加批次维度N,最终形状变为(N,C,H,W)(卷积层要求输入格式)
img_data = torch.tensor(img_data.transpose(2, 0, 1)).unsqueeze(0)
print(img_data.shape)   # 打印转换后的形状,确认是否符合要求# 定义第一个卷积层
conv1 = nn.Conv2d(4,  # 输入通道数:4对应RGBA四通道图像16,  # 输出通道数,相当于生成16个特征图kernel_size=3,  # 卷积核大小,这里相当于是3*3(或者用(3,3)表示)stride=1,   # 步长
)
# 经过conv1后,输出形状预期为(1,16,499,498)计算方式:H/W = 原尺寸 - 卷积核大小 + 1# 定义第二个卷积层
conv2 = nn.Conv2d(16,  # 输入通道数:16 与上一层输出通道数一致32,  # 输出通道数(生成32个特征)kernel_size=3,  # 卷积核大小stride=1,
)
# 经过conv2后,输出形状预期为(1,32,497,496)# 定义第三个卷积层
conv3 = nn.Conv2d(32, 2, kernel_size=3, stride=1,
)
# 输出形状预期为(1,2,495,494)# 执行卷积操作流程
out = conv1(img_data)   # 第一层卷积处理
out = conv2(out)        # 第一层卷积处理
out = conv3(out)        # 第一层卷积处理# 提取输出结果中的特定特征图
# squeeze(0):去除批次中的维度N,形状从(1,2,495,494)变为(2,495,494)
# [1]:取第二个通道的特征图(索引从0开始)
out_data = out.squeeze(0)[1]
print(out_data.shape)# 以下为注释掉的图像显示代码(功能说明)
# out = out.squeeze(0).detach().numpy().transpose(1, 2, 0)
# 解释:
# 1. squeeze(0):去除批次维度
# 2. detach().numpy():将PyTorch张量转换为NumPy数组(脱离计算图)
# 3. transpose(1, 2, 0):将形状从(C, H, W)转换为(H, W, C)(适应matplotlib显示格式)
#
# plt.imshow(out)  # 显示处理后的图像
# plt.show()       # 展示图像窗口

示例2:

import torch
import torch.nn as nndef test01(input):# 参数量:128*3*3*3conv = nn.Conv2d(in_channels=3,  # 输入通道数out_channels=128,  # 输出多少个通道(特征图)卷积核个数,每个卷积核提取图片的某一个特征kernel_size=3,  # 卷积大小stride=1,   # 步长bias=True   # 是否添加偏置)# 获取卷积层的所有参数(名称+参数张量)# named_parameters()返回一个迭代器,包含参数名称和对应的参数值name_par = conv.named_parameters()# 遍历打印每个参数的名称和形状for name, param in name_par:print(name, param.shape)# 使用定义的卷积层处理输入数据output = conv(input)return outputif __name__ == '__main__':# 创建随机输入数据:形状为(5, 3, 224, 224)# 各维度含义:(batch_size=5, in_channels=3, height=224, width=224)# torch.randn()生成符合标准正态分布的随机数input_data = torch.randn(5, 3, 224, 224)# 调用test01函数,获取卷积处理后的输出output = test01(input_data)# 打印输出张量的形状print(output.shape)

总结

        卷积是 CNN 的核心操作,通过滑动卷积核提取局部特征,结合填充、步长控制特征图大小,利用多通道和多卷积核捕捉丰富特征,同时通过参数共享降低计算成本,是处理图像等空间数据的高效工具。

http://www.xdnf.cn/news/17721.html

相关文章:

  • 利用 Java 爬虫按图搜索 1688 商品(拍立淘)实战指南
  • 高效TypeScript开发:VSCode终极配置指南
  • Varjo XR虚拟现实军用车辆驾驶与操作培训
  • 【MATLAB代码】滑动窗口均值滤波、中值滤波、最小值/最大值滤波对比。订阅专栏后可查看完整代码
  • OpenCV中对图像进行平滑处理的4种方式
  • 《多级缓存架构设计与实现全解析》
  • 【跨越 6G 安全、防御与智能协作:从APT检测到多模态通信再到AI代理语言革命】
  • 机器视觉的磁芯定位贴合应用
  • GraphRAG查询(Query)流程实现原理分析
  • Java+Vue构建的MES信息管理系统,含完整源码,功能涵盖生产跟踪、质量管控等,助力企业实现精细化、智能化生产管理
  • 【16-softmax回归】
  • AI 赋能的软件工程全生命周期应用
  • springboot+vue实现通过poi完成excel
  • Postman 平替 技术解析:架构优势与实战指南
  • 观察者模式(C++)
  • 【Leetcode hot 100】76.最小覆盖字串
  • 【HarmonyOS】Window11家庭中文版开启鸿蒙模拟器失败提示未开启Hyoer-V
  • SwiftUI 页面弹窗操作
  • 用飞算JavaAI一键生成电商平台项目:从需求到落地的高效实践
  • 使用免费API开发口播数字人
  • [机器学习]07-基于多层感知机的鸢尾花数据集分类
  • c++中的Lambda表达式详解
  • Java基础07——基本运算符(本文为个人学习笔记,内容整理自哔哩哔哩UP主【遇见狂神说】的公开课程。 > 所有知识点归属原作者,仅作非商业用途分享)
  • k8s+isulad 网络问题
  • 如何使用 AI 大语言模型解决生活中的实际小事情?
  • 【P81 10-7】OpenCV Python【实战项目】——车辆识别、车流统计(图像/视频加载、图像运算与处理、形态学、轮廓查找、车辆统计及显示)
  • 网络协议序列化工具Protobuf
  • 4.1vue3的setup()
  • 2019 GPT2原文 Language Models are Unsupervised Multitask Learners - Reading Notes
  • Kotlin Data Classes 快速上手