当前位置: 首页 > ds >正文

ARConv的复现流程

使用环境

Python 3.10.16

torch 2.1.1+cu118

torchvision 0.16.1+cu118

其它按照官方提供代码的requirements.txt安装

GitHub - WangXueyang-uestc/ARConv: Official repo for Adaptive Rectangular Convolution

数据准备

从官方主页下载pancollection数据集PanCollection for Survey Paper

以WV3 Dataset为例,我们下载训练集和测试集
[1] Training Dataset(训练数据集, 5.76GB): [Baidu Cloud]
[2] Testing Dataset(测试数据集, 20 Examples/per class): [ReducedData(H5 Format)] [FullData(H5 Format)]
 

训练

在这里我没有使用官方推荐的运行.sh文件,而是直接去调用trainer.py执行,那么我修改了两个文件以找到模型,主要是相对导入和绝对导入的问题。

ARConv/models/models.py

from ARConv import ARConv -> from .ARConv import ARConv

ARConv/trainer.py

from .models import ARNet -> from models import ARNet

 运行trainer.py进行训练,下面给出仅使用GPU 0进行训练的代码

CUDA_VISIBLE_DEVICES="0" python trainer.py --batch_size 16 --epochs 600 --lr 0.0006 --ckpt 20 --train_set_path ./pansharpening/training_data/train_wv3.h5 --checkpoint_save_path ./workdir/wv3 --hw_range 1 18 --task 'wv3'

测试

训练完毕后,模型权重pth文件被存入设定的文件目录中,经过作者的回复,和自己的补充,我写了两个python脚本getFullmat.py和getReducedmat.py分别用于生成模型输出的文件,在Matlab中进行测试。将其中的checkpoint_path 改为自己pth存放的文件路径即可。

getFullmat.py

import torch
import torch.nn as nn
import os
import scipy.io as sio
from einops import rearrange
from models import ARNet
import h5py
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def load_set(file_path):data = h5py.File(file_path)lms = torch.from_numpy(np.array(data['lms'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute([1, 0, 2, 3, 4]).float()ms = torch.from_numpy(np.array(data['ms'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute([1, 0, 2, 3, 4]).float()pan = torch.from_numpy(np.array(data['pan'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute([1, 0, 2, 3, 4]).float()return lms, ms, pan# 路径设置(请根据实际路径修改)
checkpoint_path = r'workdir/wv3/checkpoint_160_2025-05-02-16-06-33.pth'
test_data_path = r'pansharpening/test_data/WV3/test_wv3_OrigScale_multiExm1.h5'
save_dir = r'2_DL_Result/PanCollection/WV3_Full/RRNet/results/'# 创建保存目录
os.makedirs(save_dir, exist_ok=True)# 加载模型
model = ARNet().cuda()
model = nn.DataParallel(model)
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model'])
model.eval()# 加载测试数据
lms, ms, pan = load_set(test_data_path)# 推理所有图像
with torch.no_grad():print('Running model inference...')for i in range(pan.shape[0]):output = model(pan[i], lms[i], 1000, [1, 18])output = rearrange(output, 'b c h w -> b h w c') * 2047output_np = output[0].cpu().numpy()save_mat_path = os.path.join(save_dir, f'output_mulExm_{i}.mat')sio.savemat(save_mat_path, {'sr': output_np})print(f"Saved .mat to {save_mat_path}")

getReducedmat.py

import torch
import torch.nn as nn
import os
import scipy.io as sio
from einops import rearrange
from models import ARNet
import h5py
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def load_set(file_path):data = h5py.File(file_path)lms = torch.from_numpy(np.array(data['lms'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute([1, 0, 2, 3, 4]).float()ms = torch.from_numpy(np.array(data['ms'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute([1, 0, 2, 3, 4]).float()pan = torch.from_numpy(np.array(data['pan'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute([1, 0, 2, 3, 4]).float()return lms, ms, pan# 路径设置(请根据实际路径修改)
checkpoint_path = r'workdir/wv3/checkpoint_160_2025-05-02-16-06-33.pth'
test_data_path = r'pansharpening/test_data/WV3/test_wv3_multiExm1.h5'
save_dir = r'2_DL_Result/PanCollection/WV3_Reduced/RRNet/results/'# 创建保存目录
os.makedirs(save_dir, exist_ok=True)# 加载模型
model = ARNet().cuda()
model = nn.DataParallel(model)
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model'])
model.eval()# 加载测试数据
lms, ms, pan = load_set(test_data_path)# 推理所有图像
with torch.no_grad():print('Running model inference...')for i in range(pan.shape[0]):output = model(pan[i], lms[i], 1000, [1, 18])output = rearrange(output, 'b c h w -> b h w c') * 2047output_np = output[0].cpu().numpy()save_mat_path = os.path.join(save_dir, f'output_mulExm_{i}.mat')sio.savemat(save_mat_path, {'sr': output_np})print(f"Saved .mat to {save_mat_path}")

 之后将要2_DL_Result放入ARConv\MetricCode中

修改 Demo1_Reduced_Resolution_MultiExm_wv3.m 和Demo2_Full_Resolution_multi_wv3.m中的file_test路径,改为存放测试集的文件即可。

我分别修改为了

Demo1_Reduced_Resolution_MultiExm_wv3.m :

opts.file = 'test_wv3_multiExm1';
file_test = strcat('pansharpening/test_data/WV3/', opts.file,'.h5');

Demo2_Full_Resolution_multi_wv3.m:

opts.file = 'test_wv3_OrigScale_multiExm1';
file_test = strcat('pansharpening/test_data/WV3/', opts.file,'.h5');

并在路径中放入了两个测试集文件

test_wv3_multiExm1.h5和test_wv3_OrigScale_multiExm1.h5

之后运行测试即可成功完成测试

http://www.xdnf.cn/news/3677.html

相关文章:

  • btrace2.0使用方法
  • 基于FastApi实现本地部署DeepSeek-R1-Distill-Qwen与流式输出
  • 文章四《深度学习核心概念与框架入门》
  • 读书记:《认知红利》
  • 云盘系统设计
  • Vue3+Element Plus全套学习笔记-目录大纲
  • UE自动索敌插件Target System Component
  • MAAS Anvil - 高可用 MAAS 部署管理工具
  • 纳米AI搜索体验:MCP工具的实际应用测试,撰写报告 / 爬虫小红书效果惊艳
  • Matplotlib核心课程-2
  • JavaWeb学习打卡-Day7-正向代理、反向代理、Nginx
  • 使用AI-01开发板和开源后端服务搭建整套小智服务系统
  • 在运行 Hadoop 作业时,遇到“No such file or directory”,如何在windows里打包在虚拟机里运行
  • 赎金信(简单)
  • 单一职责原则(SRP)
  • 安妮推广导航系统开心版多款主题网址推广赚钱软件推广变现一键统计免授权源码Annie
  • 写了个脚本将pdf转markdown
  • C/C++工程师使用 DeepSeek
  • [面试]SoC验证工程师面试常见问题(三)
  • 2505C++,wmi客户端示例
  • MySQL:联合查询
  • Linux-07-Shell
  • 大模型在终末期肾脏病风险预测与临床方案制定中的应用研究
  • 如何封装一个线程安全、可复用的 HBase 查询模板
  • Encoder-free无编码器多模态大模型EVEv2模型架构、训练方法浅尝
  • Windows 使用set和setx设置环境变量(skywalk3)
  • 2.LED灯的控制和按键检测
  • 【MySQL】事务管理
  • 区块链+IoT:创新场景落地背后的技术攻坚战
  • Python镜像源配置: