pytorch模型画质增强简单实现
使用数据增强技术可以增加数据集中图像的多样性,从而提高模型的性能和泛化能力,主要的图像增强技术包括。
亮度,对比度调节
在开始图像大小的调整之前我们需要导入数据(图像以眼底图像为例)。
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import numpy as np
import torchvision.transforms as T
plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))
# random_crops = [T.RandomCrop(size=size)(orig_img) for size in (832,704, 256)]
colorjitter_img = [T.ColorJitter(brightness=(2,2), contrast=(0.5,0.5), saturation=(0.5,0.5))(orig_img)]
plt.figure('resize:128*128')
ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)
ax2 = plt.subplot(122)
ax2.set_title('colorjitter_img')
ax2.imshow(np.array(colorjitter_img[0]))
plt.show()
1 导入必要的库
imgae来自pil库,用于处理图像
Path;来自pathlib库,用于处理路径
matplotlib.pyplot 用于绘制图表
numpy 用于数值计算
torch 来自pytorch库
T来自torchvision.transforms 用于图像转换操作
2 定义一个用于保存图片的matplotlib 配置
设置savefig.bbox为tight ,用于自动调整图表的边距,以确保整个图保存在图像中
3 打开图片文件
使用Image.open函数打开指定路径下的图片文件
4 尝试进行随机裁剪操作(注释)
5 对图片进行颜色抖动转换
创建一个包含颜色抖动转换后的图像列表,这里使用T.ColorJitter方法,设置了亮度,对比度,饱和度的抖动范围
6 创建一个包含两个子图的画布
使用plt.figure('resize:128x128')创建一个图表窗口,标题为resize:128x128
创建两个子图
ax1:显示原始图像
ax2: 显示颜色抖动后的图像
7 在第一个子图中显示原始图像
8 在第二个子图中显示颜色抖动后的图像
9 显示整个画布
在第一个子图ax1中显示原始图像
在第二个子图ax2中显示颜色抖动后的图像
使用plt.show 显示整个图表