当前位置: 首页 > news >正文

Pytorch应用 小记 第一回:基于ResNet网络的图像定位

Pytorch应用小记 第一回

本次小记,提供了一份基于ResNet网络的图像定位代码。在本回代码中,实现了ResNet网络定位宠物头像(采用的数据集是Oxford-IIIT Pet Dataset)。除了提供代码外,本小记对代码中不容易理解的内容,也进行了讲解。
本代码的平台是PyCharm 2024.1.3,python版本3.11 numpy版本是1.26.4,pytorch版本2.0.0,d2l的版本是1.0.3


文章目录

Pytorch应用小记 第一回

一、代码

二、小记

1.代码思路

2. glob.glob('dataset/images/*.jpg')

3 .xmls_names = [x.split('\\')[-1].split('.xml')[0] for x in data_xmls]

4.xml = open(r'{}'.format(track)).read()

5. data_tree = etree.HTML(xml)

6.img_width = int(data_tree.xpath('//size/width/text()')[0])

7.label_x_min, label_y_min, label_x_max, label_y_max = list(zip(*labels))

8.num = np.random.permutation(len(images))

   data_images = np.array(images)[num]

9. imgs_data = np.repeat(imgs_data[:, :, np.newaxis], 3, axis=2)

10.resnet = torchvision.models.resnet101(weights=ResNet101_Weights.IMAGENET1K_V1)

net_feature = resnet.fc.in_features

11.self.conv_base = nn.Sequential(*list(resnet.children())[:-1])


一、代码

代码如下所示:

import torch
import torch.nn as nn
from torch.utils import data
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import os
from lxml import etree
from matplotlib.patches import Rectangle
import glob
from PIL import Image
from torch.optim import lr_scheduler
from torchvision.models import ResNet101_Weights
import time# 创建输入图像
data_images = glob.glob('dataset/images/*.jpg')
data_xmls = glob.glob('dataset/annotations/xmls/*.xml')
xmls_names = [x.split('\\')[-1].split('.xml')[0] for x in data_xmls]
images = [image for image in data_images ifimage.split('\\')[-1].split('.jpg')[0] in xmls_names]def transform_labels(track):xml = open(r'{}'.format(track)).read()data_tree = etree.HTML(xml)img_width = int(data_tree.xpath('//size/width/text()')[0])img_height = int(data_tree.xpath('//size/height/text()')[0])x_min = int(data_tree.xpath('//bndbox/xmin/text()')[0])y_min = int(data_tree.xpath('//bndbox/ymin/text()')[0])x_max = int(data_tree.xpath('//bndbox/xmax/text()')[0])y_max = int(data_tree.xpath('//bndbox/ymax/text()')[0])return [x_min / img_width, y_min / img_height, x_max / img_width, y_max / img_height]labels = [transform_labels(track) for track in data_xmls]
label_x_min, label_y_min, label_x_max, label_y_max = list(zip(*labels))
num = np.random.permutation(len(images))
data_images = np.array(images)[num]
# 数组或张量的形状调整为二维结构
label_x_min = np.array(label_x_min).astype(np.float32).reshape(-1, 1)[num]
label_y_min = np.array(label_y_min).astype(np.float32).reshape(-1, 1)[num]
label_x_max = np.array(label_x_max).astype(np.float32).reshape(-1, 1)[num]
label_y_max = np.array(label_y_max).astype(np.float32).reshape(-1, 1)[num]
segment = int(len(images) * 0.7)
train_images = data_images[:segment]
x_min_train_label = label_x_min[:segment]
y_min_train_label = label_y_min[:segment]
x_max_train_label = label_x_max[:segment]
y_max_train_label = label_y_max[:segment]test_images = data_images[segment:]
x_min_test_label = label_x_min[segment:]
y_min_test_label = label_y_min[segment:]
x_max_test_label = label_x_max[segment:]
y_max_test_label = label_y_max[segment:]img_scale = 224
transform = transforms.Compose([transforms.Resize((img_scale, img_scale)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])# 创建Dataset对象
class dataset_Oxford(data.Dataset):def __init__(self, images_in, out1_label, out2_label,out3_label, out4_label, transition):self.images = images_inself.out1_label = out1_labelself.out2_label = out2_labelself.out3_label = out3_labelself.out4_label = out4_labelself.transition = transitiondef __getitem__(self, index):img = self.images[index]out1_label = self.out1_label[index]out2_label = self.out2_label[index]out3_label = self.out3_label[index]out4_label = self.o
http://www.xdnf.cn/news/375139.html

相关文章:

  • LSP里氏替换原则
  • tmux + ttyd 原理
  • FHE 之 面向小白的引导(Bootstrapping)
  • ISP(Image Signal Processor)处理流程及不同域划分
  • 初等数论--莫比乌斯函数
  • STM32硬件I2C驱动OLED屏幕
  • [文献阅读] wav2vec: Unsupervised Pre-training for Speech Recognition
  • 优选算法——队列+BFS
  • Spark的三种部署模式及其特点与区别
  • GitHub 趋势日报 (2025年05月09日)
  • HTTP:十三.HTTP日志
  • 如何解决 PowerShell 显示 “此系统上禁用了脚本运行” 的问题
  • DAMA语境关系图汇总及考前须知
  • 【Linux系统编程】进程属性--进程状态
  • Vision Transformer(ViT)
  • 事务连接池
  • 编写第一个MCP Server之Hello world
  • 【动态导通电阻】软硬开关下GaN器件的动态RDSON
  • 养生:拥抱健康生活的秘诀
  • 深入解析JavaScript变量作用域:var、let、const全攻略
  • React Hooks:从“这什么鬼“到“真香“的奇幻之旅
  • 《基于人工智能的智能客服系统:技术与实践》
  • 二分类问题sigmoid+二元交叉熵损失
  • 微服务的“迷宫” - 我们为何需要服务网格?
  • 数据库故障排查指南:从连接问题和性能优化
  • Docker使用小结
  • 为什么选择 FastAPI、React 和 MongoDB?
  • vue 组件函数式调用实战:以身份验证弹窗为例
  • 计算机大类专业数据结构下半期实验练习题
  • 【基础IO下】磁盘/软硬链接/动静态库