当前位置: 首页 > backend >正文

用PyTorch手写透视变换

Torch,起码是较老版本,没有原生支持可微分的透视变换。为了解决,可以尝试用Torch3D,或者其他3D Torch的库。这里给一个简单的实现。需要注意,非常老的torch不支持。

  1. 构建目标图像中的像素网格坐标;
  2. 使用 ( H^{-1} ) 反向映射目标图像像素至原图坐标;
  3. grid_sample() 在原图中采样这些位置的值(双线性插值);
  4. 利用 PyTorch 的 autograd 系统自动传递梯度。

📥 输入参数

参数类型说明
imageTensor(C,H,W)输入图像,float32 张量,通道优先格式(如 RGB 图为 3×H×W)
matrixTensor(3,3)透视变换矩阵(Homography)
out_hint输出图像高度
out_wint输出图像宽度

📤 输出结果

返回值类型说明
outputTensor(C, out_h, out_w)输出透视变换后的图像张量

先来效果图
在这里插入图片描述

透视变换

透视变换(Homography),将图像按指定的 3×3 矩阵进行几何变换,也就是矩阵乘法。 输出图像大小是固定的,需要我们 将输出图像每个位置“反推”回输入图像中应该采样的位置,这叫做反向采样(inverse mapping)。

获取变换位置映射

在针对图像做各种变换时候,首先都要有一个meshgrid,用于构建像素坐标网格。对于单应性变换、旋转等都是如此。 具体实现用 arange,生成一个从 0 开始到 out_h-1 的连续整数张量。

yy, xx = torch.meshgrid(torch.arange(out_h),torch.arange(out_w),indexing='ij'
)

得到目标图像中每个像素的位置 (x, y),再构建齐次坐标:

grid = torch.stack([xx, yy, ones], dim=0).view(3, -1)  # shape: (3, H*W)

我们要找到“目标图像第 (x,y) 个像素,在源图像的哪个位置采样”,所以要用 反变换H−1H^{-1}H1 把目标图像的位置映射到源图像坐标。

H_inv = torch.inverse(matrix)
sample_coords = H_inv @ grid  # shape: (3, N)

接着,做除以第三行的归一化:

sample_coords = sample_coords[:2] / sample_coords[2:]  # shape: (2, N)

就能得到输出图像中每个点,在输入图像中的实际采样位置(浮点数)是多少。这里还得做个归一化,为了应对 grid_sample 的输入要求

x_norm = (x / (W - 1)) * 2 - 1
y_norm = (y / (H - 1)) * 2 - 1

接下来到了关键步骤,怎么用映射矩阵来执行变换?

grid_sample 函数

grid_sample 是 PyTorch 中的一个重要函数,常用于图像变换、空间变换网络(STN)、透视变换等场景。它通过提供一组采样坐标点,在输入图像上进行双线性插值或最近邻插值

📥 输入参数说明

参数名类型说明
inputTensor (B, C, H_in, W_in)输入图像或特征图,batch 格式
gridTensor (B, H_out, W_out, 2)每个输出像素在输入图像上的采样坐标,最后一维是 (x, y)
modestr,可选(默认 'bilinear'插值模式:'bilinear''nearest'
padding_modestr,可选(默认 'zeros'超出边界时的填充方式:'zeros', 'border', 'reflection'
align_cornersbool(默认 True是否将输入图像角像素映射到 [-1, 1] 的边界点

📌 坐标说明(关键)

  • grid 中的坐标是归一化的,范围是 [-1, 1]
    • (-1, -1) 表示左上角
    • (1, 1) 表示右下角
  • 这适用于所有尺寸的输入图像,PyTorch 会自动映射到实际的像素位置

所以这里要进行:

warped = F.grid_sample(image.unsqueeze(0),        # (1, C, H, W)sample_grid.unsqueeze(0),  # (1, out_h, out_w, 2)mode='bilinear',padding_mode='zeros',align_corners=True
)

结果是你想要的透视变换图像。

汇总

import torch
import torch.nn.functional as Fdef warp_perspective(image, matrix, out_h, out_w):"""image: Tensor (C, H, W)matrix: Tensor (3, 3)return: warped image (C, out_h, out_w)"""device = image.devicedtype = image.dtypeC, H, W = image.shape# 1. 构建目标图像像素网格yy, xx = torch.meshgrid(torch.arange(out_h, device=device, dtype=dtype),torch.arange(out_w, device=device, dtype=dtype),indexing='ij')ones = torch.ones_like(xx)grid = torch.stack([xx, yy, ones], dim=0).view(3, -1)  # (3, H*W)# 2. 将目标像素通过 H^-1 映射回源图像坐标H_inv = torch.inverse(matrix)sample_coords = H_inv @ grid  # (3, N)sample_coords = sample_coords[:2] / sample_coords[2:]  # (2, N)# 3. 归一化坐标到 [-1, 1]x_norm = (sample_coords[0] / (W - 1)) * 2 - 1y_norm = (sample_coords[1] / (H - 1)) * 2 - 1sample_grid = torch.stack([x_norm, y_norm], dim=-1)  # (N, 2)sample_grid = sample_grid.view(out_h, out_w, 2)sample_grid = sample_grid.unsqueeze(0)  # (1, out_h, out_w, 2)# 4. image -> (1, C, H, W)image = image.unsqueeze(0)warped = F.grid_sample(image,sample_grid,mode='bilinear',padding_mode='zeros',align_corners=True)return warped.squeeze(0)  # (C, out_h, out_w)from PIL import Image
from torchvision.transforms.functional import to_tensor
import matplotlib.pyplot as plt# 加载图片
img = Image.open("img").convert("RGB")
img_tensor = to_tensor(img).float().cuda()  # (C, H, W)# 定义 Homography(可以设置为 requires_grad=True)
H = torch.tensor([[1.0, 0.2, -30.0],[0.1, 1.0, -20.0],[0.0005, 0.0003, 1.0]
], dtype=torch.float32, device='cuda')img_tensor.requires_grad_()  # ✅ 启用梯度
H.requires_grad_()           # ✅ 如果你也想对H求导# 调用纯 Python 实现的 warp 函数
out = warp_perspective(img_tensor, H, 300, 300)# 计算 loss 并反向
loss = out.mean()
loss.backward()# 打印梯度信息
print("Image Grad:", img_tensor.grad.shape)
print("Matrix Grad:", H.grad)# 可视化结果
plt.imshow(out.permute(1, 2, 0).detach().cpu().numpy())
plt.axis('off')
plt.title('Warped Image')
plt.show()
http://www.xdnf.cn/news/15532.html

相关文章:

  • 嵌入式学习-PyTorch(5)-day22
  • Towards Low Light Enhancement with RAW Images 论文阅读
  • ASP.NET Core Hosting Bundle
  • Debian 12中利用dpkg命令安装MariaDB 11.8.2
  • C++11迭代器改进:深入理解std::begin、std::end、std::next与std::prev
  • 在 kubernetes 上安装 jenkins
  • 数据结构自学Day7-- 二叉树
  • I3C通信驱动开发注意事项
  • PHP连接MySQL数据库的多种方法及专业级错误处理指南
  • 本地 LLM API Python 项目分步指南
  • Neo4j Python 驱动库完整教程(带输入输出示例)
  • HCIA第三次综合实验:VLAN
  • python实现自动化sql布尔盲注(二分查找)
  • 清理C盘--办法
  • Redis集群搭建(主从、哨兵、读写分离)
  • 26.将 Python 列表拆分为多个小块
  • Kafka 4.0 技术深度解析
  • 尚庭公寓-----day1----逻辑删除功能
  • PHP语法高级篇(三):Cookie与会话
  • 构建 Go 可执行文件镜像 | 探索轻量级 Docker 基础镜像(我应该选择哪个 Docker 镜像?)
  • DGNNet:基于双图神经网络的少样本故障诊断学习模型
  • element plus使用插槽方式自定义el-form-item的label
  • 3D数据:从数据采集到数据表示,再到数据应用
  • 本地电脑安装Dify|内网穿透到公网
  • 【Qt】插件机制详解:从原理到实战
  • 【科研绘图系列】R语言绘制中国地图和散点图以及柱状图
  • ES2023 新特性解析_数组与对象的现代化操作指南
  • 一文厘清楼宇自控系统架构:包含哪些关键子系统及其作用
  • 部署项目将dll放到system32?不可取
  • 基于LAMP环境的校园论坛项目