当前位置: 首页 > web >正文

Pytorch超分辨率模型实现与详细解释

下面我提供一个完整Pytorch超分辨率模型实现每一行代码进行详细解释包括所有引用头文件

import torch #导入pytorch用于构建训练神经网络主要框架

import torch.nn as nn #导入pytorch神经网络模块-包含各种神经网络函数

import torch.nn.functional as F #导入Pytorch神经网络函数模块--包含激活函数损失函数

import torch.utils.data import DataLoader #导入pytorch数据加载工具用于创建管理数据加载

from torchvision import datasets, transforms #导入torchvision数据集和变换模块--提供常用数据图像预处理方法

import matplotlib.pyplot as plt #导入matplotlib pyplot模块-用于绘制图表可视化结果

import numpy as np #导入numpy用于数值计算特别是在处理图像数据

import os #导入操作系统接口模块用于处理文件目录路径

import time #导入时间模块--用于测量训练时间

设置设备GPU如果可用否则CPU

torch.cuda.is_avaiable() 检查当前系统是否可用CUDA GPU

#如果使用GPU加速计算否则使用CPU

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

1 定义ESPCN模型

class ESPCN(nn.Module):

#初始化方法定义模型结构

def __init__(self, upscale_factor, num_channels=1):

初始化ESPCN模型

参数

upscale_factor 放大倍数

num_channels 输入图像通道数默认为1(灰度图)

#调用nn.Module 初始化方法

super(ESPCN, self).__init__()

#第一个卷积层提取特征

nn.Conv2d: 2D卷积层用于处理图像数据

参数输入通道数输出通道数卷积核大小填充大小

这里使用5x5卷积核填充2保持空间尺寸不变

self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=5, padding=2)

#第二卷积进一步处理特征

输入64通道输出32通道3x3卷积核填充1保持尺寸

self.conv2 = nn.Conv2d(64,32,kernel_size=3,padding=1)

#最后一个卷积层生成放大特征图

输出通道数num_channels (upscale_factor *2)

这是因为我们像素pixel_shuffle 来提升分辨率

self.conv3 = nn.Conv2d(32, num_channels (upscale_factor *2), kernel_szie=3, padding=1)

#像素操作子像素卷积层

pixelshuffle 形状 C * r^2, H, W张量

重新排列

self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

#定义向前传播过程描述数据如何通过网络各层

def forward(self, x):

向前传播

参数

x 输入第分辨率图像形状为(batch_size, num_channels, height, width)

返回 分辨率图像形状为(batch_szie, num_channels, height upscale_factor, width upscale_factor)

#第一层卷积使用tanh激活函数

torch.tanh 正切激活函数压缩(-1,1)范围

x = torch.tanh(self.conv1(x))

#第二层卷积后使用tanh激活函数

x = torch.tanh(self.conv2(x))

#第三层卷积

x = self.conv3(x)

#应用像素操作通道维度转换为空间维度

x = self.pixel_shuffle(x)

#使用sigmoid激活函数压缩到(0,1)范围

#这是因为图像像素值通常0-1之间

x = torch.sigmoid(x)

return x

2 准备数据

def prepare_data(batch_szie, upscale_factor, dataset_name='MNIST'):

准备训练预测数据

参数

batch_szie 批处理大小

upscale_factor 放大倍数

dataset_name 数据集名称默认为MNIST

返回

训练测试数据加载

数据转换管道

transforms.Compose 多个变换组合在一起

transform = transforms.Compose([

#transforms.ToTensor PIL图像或者numpy数组转换Pytorch张量

#同时 像素值[0.255]缩放[0,1]范围

transforms.ToTensor(),

#transforms.Normalize 张量进行标准化

#参数均值标准差这里标准化[-1,1]范围

transforms.Normalize(0.5, (0.5,))

])

#根据数据集名称选择不同数据集

if dataset_name == 'MNIST':

#下载加载MNIST训练数据集

#MNIST是一个手写数字数据集包含60 000训练样本10 000测试样本

train_dataset = datasets.MNIST(

root='./data' #数据存储路径

train=True, #加载训练集

download=True,#如果数据不存在下载

transform = transform #应用上面定义数据转换

)

#下载并加载MNIST测试数据集

test_dataset = datasets.MNIST(

root='./data',

train=False, #加载测试集

download=True,

transform=transform

)

else:

#可以在这里添加其他数据集支持

raise ValueError("不支持的数据集:{dataset_name}")

#创建训练数据加载起

DataLoader 包装数据集提供批量加载shuffling 功能

train_loader = DataLoader(

train_dataset,

batch_size = batch_size, #每个批次样本数量

shuffle=True, 每个epoch 开始打乱数据顺序

num_works = 2, 使用2子进程加载数据

pin_memory=True #数据固定内存中加速GPU传输

)

#创建测试数据加载

test_loader = DataLoader(

test_dataset,

batch_size=batch_size,

shuffle=False, #测试不需要打乱顺序

num_works=2,

pin_memory=True

)

#返回训练测试数据加载

return train_laoder, test_loader

3 训练函数

def train_model(model, train_loader, criterion, optimizer, num_epochs, upscale_factor):

训练模型

参数

model:要训练模型

train_loader: 训练数据加载

criterion 损失函数

optimizer: 优化

num_epochs: 训练

upscale_factor: 放大倍数

#设置模型训练模式

#启用dropoutbatch normalization 训练特定行为

model.train()

#记录训练过程损失值

losses = []

#记录训练开始时间

start_time = time.time()

#循环遍历每个epoch

for epoch in range(num_epochs):

#初始化当前epoch损失

epoch_loss = 0

#遍历训练数据加载起每个批次

for batch_idx, (data, target) in enumerate(train_loader):

#数据移动到相应设备 GPU或者CPU

data = data.to(device)

#创建分辨率输入

#首先图像下采样然后上采样原始大小模拟分辨率图像

#F.interpolate: 图像进行采样或者下采样

#scale_factor= 1/upscale_factor 下采样比例

mode = 'bicubic' 使用双三次循环算法

align_corners= False: 差值算法参数

lr_data = F.interpolate(

data,

scale_factor=1/upscale_factor,

mode='bicubic',

align_corners = False

)

#下采样图像采样原始尺寸

lr_data = F.interpolate(

lr_data,

scale_factor = upscale_factor,

mode = 'bicubic',

align_corners=False

)

#清零梯度

pytorch 梯度累加所以在每个批次开始需要清零

optimizer.zero_grad()

#前向传播降低分辨率图像输入模型得到分辨率输出

output = model(lr_data)

#计算损失比较模型输出原始分辨率图像

loss = criterion(output, data)

#反向传播计算梯度

loss.backward()

#更新权重根据梯度调整模型参数

optimizer.step()

#累加当前批次损失

epoch_loss += loss.item()

#计算当前epoch平均损失

losses.append(avg_loss)

#打印训练进度

if (epoch + 1) % 5 == 0:

#计算已用时间

elapsed_time = time.time() - start_time

#打印当前epoch,epoch损失值已用时间

#训练绘制损失曲线

plt.figure(figsize=(10,5))

plt.plot(losses)

plt.title('Training loss over epochs')

plt.xlabel('Epoch')

plt.ylabel('Loss')

plt.grid(True)

#保存损失曲线图像

plt.savefig('training_loss.png')

plt.show()

#打印训练时间

total_time = time.time() - start_time

4 测试函数

#定义模型测试函数

def test_model(model, test_loader, upscale_factor, num_examples=5):

测试模型显示结果

参数

model: 要测试模型

test_loader: 测试数据加载器

upscale_factor: 放大倍数

num_examples: 显示示例数量

#设置模型评估模式

#禁用dropoutbatch normalization训练特定行为

model.eval()

#初始化示例计数器

examples_shown = 0

#不计算梯度节省内存计算资源

with torch.no_grad():

#遍历测试数据加载器

for i, (data, target) in enumerate(test_loader):

#如果已经显示了足够示例退出循环

if examples_shown >= num_examples:

break

#数据移动到相应设备

data = data.to(device)

#创建分辨率输入(与训练时间相同的方法)

lr_data = F.interpolate(

data,

scale_factor = 1/upscale_factor,

model='bicubic',

align_corners = False

)

lr_data = F.interpolate(

lr_data,

scale_factor = upscale_factor,

mode = 'bicubic',

align_corners = False

)

#生成分辨率图像

hr_output = model(lr_data)

#图像CPU转换为numpy数组以便显示

lr_image = lr_data[0].cpu().sequeeze().numpy()

hr_image = hr_output[0.cpu().sequeeze().numpy()

original_image = data[0].cpu().squeeze().numpy()

#显示结果

plt.figure(figsize=(12,4))

#显示分辨率输入图像

plt.subplot(1,3,1)

plt.imshow(lr_image, cmap='gray')

plt.title('Low Resolution Input')

plt.axis('off')

#显示分辨率输出图像

plt.subplot(1,3,2)

plt.imshow(hr_iamge, cmap='gray')

plt.title('Super Resolution Output')

plt.axis('off')

#显示原始高分辨率图像

plt.subplot(1,3,3)

plt.imshow(original_image, cmap='gray')

plt.title('Original high Resolution')

plt.axis('off')

#保存对比图像

plt.savefig(f'comparsion_example_{examples_shown+1}.png')

plt.show()

#增加示例计数器

examples_shown += 1

5 计算PSNR指标函数

def calculate_psnr(model, test_loader, upscale_factor):

计算模型峰值信噪比PSNR

参数model评估模型

test_loader :测试数据加载器

upscale_factor 放大倍数

返回平均PSNR

#设置模型评估模式

moel.eval()

#初始化PSNR总和样本计数

total_psnr=0.0

total_samples=0

#不计算梯度

with torch.no_grad():

#遍历测试数据加载器

for data, _ in test_loader:

#数据移动到相应设备

data = data.to(device)

#创建分辨率输入

lr_data = F.interpolate(

data,

scale_factor=1/upscale_factor,

mode = 'bicubic',

align_corners = False

)

lr_data = F.interpolate(

lr_data,

scale_factor = upscale_factor,

mode='bicubic',

align_corners=False

)

#生成分辨率图像

hr_output=model(lr_data)

#计算每个样本PSNR

for i in range(data.size(0)):

#张量转换numpy数组

original=data[i].cpu().numpy()

reconstructed=hr_output[i].cpu().numpy()

#计算均方差误差(MSE)

mse = np.mean((original - reconstructed) ** 2)

#避免除以

if mse == 0:

psnr = 100 #无穷PSNR这里100

else :

#计算PSNR 20 log10(MAX) - 10 log10(MSE)

#对于[0,1]范围图像 MAX = 1

psnr=20 np.log10(1.0) - 10 np.log10(mse)

#累加PSNR

total_psnr += psnr

total_samples += 1

#计算平均PSNR

avg_psnr = total_psnr / total_samples

return avg_psnr

6 函数

#定义函数组织整个训练和测试流程

def main():

#参数设置

upscale_factor=2 #放大倍数

num_epochs = 20 #训练

batch_szie = 64 #批处理大小

learning_rate = 0.001 学习率

#创建输出目录如果不存在

if not os.path.exists('results'):

os.makedirs('results')

#准备数据

train_loader, test_loader = prepare_data(batch_size, upscale_factor)

#初始化模型

model = ESPCN(upscale_factor=upscale_factor).to(device)

#打印模型结构

print(model)

#计算模型参数数量

total_params = sum(p.numel() for p in model.parameters())

#定义损失函数 - 均方误差损失

criterion = nn.MSELoss()

#定义优化器Adam优化器

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

#训练模型

train_model(model, train_loader, criterion, optimizer, num_epochs, upscale_factor)

#测试模型

test_model(model, test_loader, upscale_factor)

#计算PSNR

calculate_psnr(model, test_loader, upscale_factor)

#保存模型

model_path='results/espcn_model.pth'

torch.save(model.state_dcit(), model_path)

if __name__=="__main__":

main()

头文件解释总结

1 torch:pytorch提供张量操作自动求导功能

2 torch.nn:pytorch神经网络模块包含各种损失函数

3 torch.nn.functional:pytorch函数接口包含激活函数损失函数

4 torch.utils.data pytorch视觉提供常用数据集图像变换

5 matplotlib.pyplot会图库用于可视化结果

6 torchvision pytorch视觉提供常用数据集图像变换

7 numpy 数据计算用于处理数组数据

8 os 操作系统接口用于处理文件目录

9 time时间模块用于测量运动时间

http://www.xdnf.cn/news/19181.html

相关文章:

  • CRYPT32!CryptMsgUpdate函数分析和asn.1 editor nt5inf.cat 的总览信息
  • 机器学习回顾——逻辑回归
  • Consul 操作命令汇总 - Prometheus服务注册
  • 计算机视觉与深度学习 | 视觉里程计技术全景解析:从原理到前沿应用
  • 2024年09月 Python(四级)真题解析#中国电子学会#全国青少年软件编程等级考试
  • 项目一系列-第8章 性能优化Redis基础
  • 星链调查(SOS)线上问卷调查:全流程标准化实践与核心优势深挖
  • 第三届机械工程与先进制造智能化技术研讨会(MEAMIT2025)
  • 【NJU-OS-JYY笔记】操作系统:设计与实现
  • 锂电池充电芯片 XSP30支持PD/QC等多种快充协议支持最大充电电流2A
  • Origin绘制四元相图
  • [Linux]学习笔记系列 -- mm/shrinker.c 内核缓存收缩器(Kernel Cache Shrinker) 响应内存压力的回调机制
  • 深入解析PCIe 6.0拓扑架构:从根复合体到端点的完整连接体系
  • 宜春城区光纤铺设及接口实地调研
  • C5仅支持20MHZ带宽,如果路由器5Gwifi处于40MHZ带宽信道时,会出现配网失败
  • Pytest 插件方法:pytest_runtest_makereport
  • Stream API 讲解
  • Day17_【机器学习—在线数据集 鸢尾花案例】
  • 宜春城区SDH网图分析
  • 漫谈《数字图像处理》之浅析图割分割
  • 从9.4%到13.5%:ICDM2025录取率触底反弹,竞争压力稍缓
  • 新工具-mybatis-flex学习及应用
  • 大模型应用开发笔记(了解篇)
  • 使用 Bright Data Web Scraper API + Python 高效抓取 Glassdoor 数据:从配置到结构化输出全流程实战
  • Vue 项目首屏加载速度优化
  • 阿里云百炼智能体连接云数据库实践(DMS MCP)
  • AI-调查研究-64-机器人 从零构建机械臂:电机、减速器、传感器与控制系统全剖析
  • 深入解析Qt节点编辑器框架:交互逻辑与样式系统(二)
  • 如何使用 Vector 连接 Easysearch
  • cloudflare-ddns