当前位置: 首页 > news >正文

SCINet 训练代码修改

不多说,放代码

import os
import sys
import time
import glob
import numpy as np
import torch
import utils
from PIL import Image
import logging
import argparse
import torch.utils
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.autograd import Variablefrom model import *
from multi_read_data import MemoryFriendlyLoaderparser = argparse.ArgumentParser("SCI")
parser.add_argument('--batch_size', type=int, default=16, help='batch size')
parser.add_argument('--cuda', default=True, type=bool, help='Use CUDA to train model')
parser.add_argument('--gpu', type=str, default='0', help='gpu device id')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--epochs', type=int, default=100, help='epochs')
parser.add_argument('--lr', type=float, default=0.0003, help='learning rate')
parser.add_argument('--stage', type=int, default=3, help='epochs')
parser.add_argument('--save', type=str, default='EXP/', help='location of the data corpus')args = parser.parse_args()os.environ["CUDA_VISIBLE_DEVICES"] = args.gpuargs.save = args.save + '/' + 'Train-{}'.format(time.strftime("%Y%m%d-%H%M%S"))
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
model_path = args.save + '/model_epochs/'
os.makedirs(model_path, exist_ok=True)
image_path = args.save + '/image_epochs/'
os.makedirs(image_path, exist_ok=True)log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)logging.info("train file name = %s", os.path.split(__file__))if torch.cuda.is_available():if args.cuda:torch.set_default_tensor_type('torch.cuda.FloatTensor')if not args.cuda:print("WARNING: It looks like you have a CUDA device, but aren't " +"using CUDA.\nRun with --cuda for optimal training speed.")torch.set_default_tensor_type('torch.FloatTensor')
else:torch.set_default_tensor_type('torch.FloatTensor')def save_images(tensor, path):image_numpy = tensor[0].cpu().float().numpy()image_numpy = (np.transpose(image_numpy, (1, 2, 0)))im = Image.fromarray(np.clip(image_numpy * 255.0, 0, 255.0).astype('uint8'))im.save(path, 'png')def main():if not torch.cuda.is_available():logging.info('no gpu device available')sys.exit(1)np.random.seed(args.seed)cudnn.benchmark = Truetorch.manual_seed(args.seed)cudnn.enabled = Truetorch.cuda.manual_seed(args.seed)logging.info('gpu device = %s' % args.gpu)logging.info("args = %s", args)model = Network(stage=args.stage)model.enhance.in_conv.apply(model.weights_init)model.enhance.conv.apply(model.weights_init)model.enhance.out_conv.apply(model.weights_init)model.calibrate.in_conv.apply(model.weights_init)model.calibrate.convs.apply(model.weights_init)model.calibrate.out_conv.apply(model.weights_init)model = model.cuda()optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=3e-4)MB = utils.count_parameters_in_MB(model)logging.info("model size = %f", MB)print(MB)train_low_data_names = '/root/autodl-tmp/our485/low'TrainDataset = MemoryFriendlyLoader(img_dir=train_low_data_names, task='train')test_low_data_names = '/root/autodl-tmp/eval15/low'TestDataset = MemoryFriendlyLoader(img_dir=test_low_data_names, task='test')# 创建 CUDA 随机数生成器g = torch.Generator(device='cuda')g.manual_seed(args.seed)train_queue = torch.utils.data.DataLoader(TrainDataset, batch_size=args.batch_size,pin_memory=True, num_workers=0, shuffle=True, generator=g)test_queue = torch.utils.data.DataLoader(TestDataset, batch_size=1,pin_memory=True, num_workers=0, shuffle=True, generator=g)total_step = 0for epoch in range(args.epochs):model.train()losses = []for batch_idx, (input, _) in enumerate(train_queue):total_step += 1input = Variable(input, requires_grad=False).cuda()optimizer.zero_grad()loss = model._loss(input)loss.backward()nn.utils.clip_grad_norm_(model.parameters(), 5)optimizer.step()losses.append(loss.item())logging.info('train-epoch %03d %03d %f', epoch, batch_idx, loss)logging.info('train-epoch %03d %f', epoch, np.average(losses))utils.save(model, os.path.join(model_path, 'weights_%d.pt' % epoch))if epoch % 50 == 0 and total_step != 0:logging.info('train %03d %f', epoch, loss)model.eval()with torch.no_grad():for _, (input, image_name) in enumerate(test_queue):input = Variable(input, volatile=True).cuda()# image_name = image_name[0].split('\\')[-1].split('.')[0]image_name_str = image_name[0]# 使用上述方法处理 image_name_strimage_name = os.path.basename(image_name_str)  # 这里使用 os.path.basename 方法image_name = image_name.split('.')[0]  # 如果还需要去掉文件扩展名,可以再进行一次分割illu_list, ref_list, input_list, atten = model(input)u_name = '%s.png' % (image_name + '_' + str(epoch))u_path = os.path.join(image_path, u_name)  save_images(ref_list[0], u_path)if __name__ == '__main__':main()

注意事项:1:随机数生成器要做cuda上
2.保存路径要修改好。

http://www.xdnf.cn/news/311473.html

相关文章:

  • Windows系统升级Nodejs版本
  • Pulse Control LSI vs CPU for motion control
  • 基于STM32、HAL库的TSC2007IPWR触摸屏控制器驱动程序设计
  • MD2card + Deepseek 王炸组合 一键制作小红书知识卡片
  • hybird接口
  • Flutter 合并 ‘dot-shorthands‘ 语法糖,Dart 开始支持交叉编译
  • 左顾右盼-第16届蓝桥第5次STEMA测评Scratch真题第2题
  • java每日精进 5.06【框架之功能权限】
  • 永磁同步电机控制算法-反馈线性化直接转矩控制
  • vue项目生产环境中,nginx的配置
  • 在c++中老是碰到string,这是什么意思?
  • AI大模型驱动的智能座舱研发体系重构
  • 【Linux系统篇】:Linux线程同步---条件变量,信号量与CP模型实现
  • Python cv2形态学操作:从基础原理到实战应用
  • 《AI大模型应知应会100篇》第49篇:大模型应用的成本控制策略
  • Python之pip图形化(GUI界面)辅助管理工具
  • 校内周赛题(思维题)
  • 代码随想录算法训练营第60期第二十八天打卡
  • 系统架构师2025年论文《论软件系统架构评估及其应用》
  • TS 泛型
  • 网络的搭建
  • SSTI学习
  • 系统思考:选择大于努力
  • AI Agent(4):Agent核心技术栈
  • VTK|结合qt创建通用按钮控制显隐(边框、坐标轴、点线面)
  • 【原创】批量区分横屏竖屏照片
  • 云计算与大数据进阶 | 25、可扩展系统构建
  • Mybatis-核心源码相关
  • kaggle注册问题
  • 瑞克的CTF