当前位置: 首页 > backend >正文

深度学习--卷积神经网络数据增强

文章目录

  • 一、数据增强
    • 1、什么是数据增强?
    • 2、为什么需要数据增强?
  • 二、常见的数据增强方法
    • 1、图像旋转
    • 2、图像翻转
    • 3、图像缩放
    • 4、图像平移
    • 5、图像剪切
    • 6、图像亮度、对比度、饱和度调整
    • 7、噪声添加
    • 8、随机扰动
  • 三、代码实现
    • 1、预处理
    • 2、使用数据增强增加训练数据
    • 3、代码中文件
      • 1)train.txt
      • 2)test.txt
      • 3)20个图像输出
  • 总结


一、数据增强

1、什么是数据增强?

**数据增强(Data Augmentation):**缓解深度学习中数据不足的场景,在图像领域首先得到广泛使用,进而延伸到 NLP 领域,并在许多任务上取得效果。一个主要的方向是增加训练数据的多样性,从而提高模型泛化能力。本文将深入探讨数据增强的原理、常用方法及其在CNN中的应用实践。

2、为什么需要数据增强?

CNN模型通过多层卷积和池化操作提取图像特征,但模型的参数量庞大,容易在小规模数据集上过拟合(即模型过度记忆训练数据,无法泛化到新样本)。数据增强通过对原始数据进行合理的变换,生成多样化的新数据.

  1. 扩充数据集规模:缓解数据稀缺问题。
  2. 引入多样性:模拟真实场景中可能出现的视角、光照、遮挡等变化。
  3. 提升鲁棒性:迫使模型学习更本质的特征,而非数据中的噪声或无关细节。

二、常见的数据增强方法

在这里插入图片描述

1、图像旋转

随机旋转图像一定角度,模拟不同角度观察物体的情况。

2、图像翻转

随机水平或垂直翻转图像,模拟不同方向观察物体的情况。

3、图像缩放

随机调整图像尺寸,模拟物体距离不同的情况。

4、图像平移

随机平移图像一定距离,模拟物体在不同位置的情况。

5、图像剪切

随机裁剪图像一部分,模拟物体遮挡的情况。

6、图像亮度、对比度、饱和度调整

随机调整图像的亮度、对比度和饱和度,模拟不同光照条件下的情况。

7、噪声添加

随机向图像中添加噪声,模拟真实世界中的噪声干扰。

8、随机扰动

随机对图像进行拉伸、扭曲等几何变换,模拟物体形状的变化。


三、代码实现

1、预处理

import torch
from torch.utils.data import DataLoader,Dataset  # 导入打包加载库,Dataset表示数据集的抽象概念,可以被自定义的数据集继承和实现
import numpy as np
from PIL import Image
from torchvision import transformsdata_transforms = {'train':    # 训练集  也可以使用PIL库  smote 训练集transforms.Compose([  # transforms.Compose用于将多个图像预处理操作整合在一起transforms.Resize([300,300]),   # 使图像变换大小transforms.RandomRotation(45),   # 随机旋转,-42到45度之间随机选transforms.CenterCrop(256),    # 从中心开始裁剪[256.256]transforms.RandomHorizontalFlip(p=0.5),  # 随机水平旋转,随机概率为0.5transforms.RandomVerticalFlip(p=0.5),  # 随机垂直旋转,随机概率0.5transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),   # 随机改变图像参数,参数分别表示 亮度、对比度、饱和度、色温transforms.RandomGrayscale(p=0.1),  # 概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),   # 将PIL图像或NumPy ndarray转换为tensor类型,并将像素值的范围从[0, 255]缩放到[0.0, 1.0],默认把通道维度放在前面transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])  # 给定均值和标准差对图像进行标准化,前者为均值,后者为标准差,三个值表示三通道图像]),'valid':  # 验证集transforms.Compose([   # 整合图像处理的操作transforms.Resize([256,256]),   # 缩放图像尺寸transforms.ToTensor(),   # 转换为torch类型transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])    # 标准化])
} 

2、使用数据增强增加训练数据

class food_dataset(Dataset):   # food_dataset是自己创建的类名称,继承Dataset类def __init__(self,file_path,transform=None):   # 类的初始化,解析数据文件txt,file_path表示文件路径,transform可选的图像转换操作self.file_path = file_path  # 将文件地址传入self空间self.imgs = []self.labels = []self.transform = transform  # 将数据增强操作传入self空间with open(self.file_path) as f:  # 打开存放图片地址及其类别的文本文件train.txt,samples = [x.strip().split(' ') for x in f.readlines()]   # 遍历文件里的每一条数据,经过处理后存入sample列表,元祖的形式存放for img_path,label in samples:  # 遍历列表中的每个元组的每个元素self.imgs.append(img_path)   # 将图像的路径存入img列表self.labels.append(label)     # 将图片类别标签存入label列表
# 初始化:把图片目录加载到self.def __len__(self):    # 类实例化对象后,可以使用len函数测量对象的个数return len(self.imgs)   # 返回数据集中样本的总数def __getitem__(self, idx):   # 关键,可通过索引idx的形式获取每一个图片数据及标签image = Image.open(self.imgs[idx])   # 使用PIL库中的用法Image打开并识别图像,还不是tensorif self.transform:    # 判断是否有图像转换操作,上述定义默认为None,有则将pil图像数据转换为tensor类型image = self.transform(image)   # 图像处理为256*256,转换为tenorlabel = self.labels[idx]   # label还不是tensorlabel = torch.from_numpy(np.array(label,dtype=np.int64))  # 首先指定标签类型为int型,然后将其转换为numpy数组类型,然后再使用torch.from_numpy转换为torch类型return image,label  # 返回处理完的图片和标签# 次数导入上述定义函数中的参数
training_data = food_dataset(file_path = './train.txt',transform = data_transforms['train'])
test_data = food_dataset(file_path ='./test.txt',transform = data_transforms['valid'])
# 判断当前使用的是cpu好事gpu
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

注意:做了数据增强不代表模型训练效果一定会变好,大概率会变好。

3、代码中文件

1)train.txt

在这里插入图片描述

2)test.txt

在这里插入图片描述

3)20个图像输出

在这里插入图片描述
在这里插入图片描述


总结

数据增强是CNN训练中简单却高效的“免费午餐”,通过模拟真实世界的数据多样性,显著提升模型的泛化能力。随着AutoML技术的发展,自动化、自适应增强策略正成为新的趋势。然而,数据增强并非万能,其效果需结合实际任务需求和数据特性进行验证。最终,合理的数据增强策略应服务于一个核心目标:让模型学会关注本质特征,而非记住数据。

http://www.xdnf.cn/news/1255.html

相关文章:

  • TP(张量并行)和EP(专家并行)的区别
  • C++学习之游戏服务器开发十二nginx和http
  • 从信息泄露到内网控制
  • STM32外部中断与外设中断区别
  • 数据结构——队列
  • 华为交换机命令笔记
  • 【springsecurity oauth2授权中心】将硬编码的参数提出来放到 application.yml 里 P3
  • C++23 中 static_assert 和 if constexpr 的窄化布尔转换
  • Agent智能体ReAct机制深度解读:推理与行动的完美闭环
  • 实战华为1:1方式1 to 2 VLAN映射
  • hbuilderx云打包生成的ipa文件如何上架
  • 发送百度地图的定位
  • 7.6 GitHub Sentinel后端API实战:FastAPI高效集成与性能优化全解析
  • OpenCV中的透视变换方法详解
  • 【AI模型学习】Swin Transformer——优雅的模型
  • 【含文档+PPT+源码】基于微信小程序的健康饮食食谱推荐平台的设计与实现
  • 【微知】git reset --soft --hard以及不加的区别?
  • 入住刚装修好的新房,房间隔音太差应该怎么办?
  • Unity 带碰撞的粒子效果
  • OpenVINO教程(三):使用NNCF进行模型量化加速
  • MATLAB Coder 应用:转换 MATLAB 代码至 C/C++ | 实践步骤与问题解决
  • 【Pandas】pandas DataFrame truediv
  • 【程序员 NLP 入门】词嵌入 - 上下文中的窗口大小是什么意思? (★小白必会版★)
  • RESTful API 设计原则
  • 深度学习基石:神经网络核心知识全解析(一)
  • Curl用法解析
  • 前端频繁调用后端接口问题思考
  • 2025年4月22日(平滑)
  • 【Python笔记 03 】运算符
  • n8n更新1.87后界面报错Connection lost解决