pytorch例子计算两张图相似度
Siamese网络在MNIST数据集上的应用
下面我将详细解释这段代码的每一部分,帮助你理解这个Siamese网络在MNIST数据集上的实现。
#导入必要的库和模块
from __future__ import print_function #兼容Python2和Python3 的print 函数。
import argparse, random, copy #命令行参数解析,随机数生成,对象复制。
import numpy as np #数值计算库
import torch #Pytorch深度学习框架
import torch.nn as nn#神经网络模块
import torch.nn.functional as F #神经网络函数
import torch.optim as optim #优化器
import torchvision #计算机视觉库
from torch.utils.data import Dataset #数据集基类
from torchvision import datasets #预置数据集
from torchvision import transforms as T #数据预处理
from torch.optim.lr_scheduler import StepLR #学习率调度器
Siamese网络模型定义
class SiameseNetwork(nn.Module):
Siamese网络用语图像像素度估计
网络由两个相同的子网络组成,每个处理一个输入
两个子网络的输出被拼接后传入线性层,再通过sigmoid函数输出相似度
def __init__(self):
super(SiameseNetwork, self).__init__()
#获取ResNet模型
self.resnet = torchvision.models.resnet18(weights=None) #不使用训练权重
#修改第一层卷积以适应MNIST的单通道图像
#原始ResNet设计用于3通道RGB图像,MNIST时单通道灰度图
self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias = False)
#获取全连接层的输入特征数
self.fc_in)features = self.resnet.fc.in_features
#移除ResNet的最后一层(全连接层)
self.resnet = torch.nn.Sequential(*(list(self.resnet.children())[:-1]))
#定义比较两个图像特征的线性层
self.fc = nn.Sequential (
nn.Linear(self.fc_in_features * 2, 256), #拼接后特征维度翻倍
nn.Relu(inplace=True), #激活函数
nn.Linear(256,1), #输出一个相似度分数
)
#初始化权重
self.sigmoid = nn.Sigmoid() #将输出映射到[0,1]范围
#初始化权重
self.resnet.apply(self.init_weights)
self.fc.apply(self.init_weights)
def init_weights(self, m):
初始化权重函数
if isinstance(m, nn.Linear):
torch.nn.init_xavier_uniform_(m.weight) #Xavier均匀初始化
m.bias.data.fill_(0.01) #偏置初始化为0.01
def forward_once(self, x):
单个分枝的前向传播
output = self.resnet(x) #通过ResNet主干
output = output.view(output.size()[0], -1) #展开特征图
return output
def forward(self, input1, input2):
整个网络的前向传播
#截取两个图像的特征
output1 = self.forward_once(input1)
output2 = self.forward_once(input2)
#沿特征维度拼接两个特征向量
output = torch.cat((output1, output2), 1)
#通过全连接层
output = self.fc(output)
#通过sigmoid函数输出相似度概率
return output
自定义数据集类
class APP_MATCHER(Dataset):
自定义数据集类,用于生成成对的MNIST图像
def __init__(self, root, train, download = False):
super(APP_MATCHER, self).__init__()
#加载MNIST数据集
self.dataset = datasets.MNIST(root, train = tarin, download= download)
#为图像添加通道维度(N,1,28,28)
self.data = self.dataset.data.unsqueeze(1).clone()
#按类别分组示例
self.group_examples()
def group_examples(self):
按类别分组示例,便于后续采样
#获取所有标签
np_arr = np.array(self.data.set.targets.clone(), dtype=None, copy=None)
#创建类别到索引的映射
self.grouped_examples = {}
for i in range(0, 10) #0 ~9一共10个类别
self.grouped_examples[i] = np.wherre((np_arr==i))[0] #找出该类别的所有索引
def __len__(self):
返回数据集大小
retunr self.data.shape[0]
def __getitem__(self, index):
获取一对图像的标签
偶数索引,相同类别的不同图像 正样本
奇数索引,不同类别的图像,负样本
#随机选择一个类别
selected_class = random.randint(0, 9)
#随机选择该类别的第一个图像
random_index_1 = random.randint(0, self.grouped_examples[selected_class].shape[0] - 1)
index_1 = self.grouped_examples[selected_class][random_index_1]
image_1 = self.data[index_1].clone().float() #转为浮点数
#生成正样本对(相同类别)
if index%2 == 0:
#选择同类别但不同的图像
random_index_2 = random.randint(0, self.grouped_examples[selected_class].shape[0] - 1)
while random_index_2 = random_index_1: #确保不是同一个图像
random_index_2 = random.randint(0, self.grouped_examples[selected_class].shape[0] - 1)
index_2 = self.grouped_examples[selected_class][random_index_2]
image_2 = self.data[index_2].clone().float()
target = torch.tensor(1, dtype=torch.float) #标签为1相似
#生存负样本对,不同类别
other_selected_class = random.ranint(0, 9)
while other_selected_class == selected_class: #确保类别不同
other_selected_class = random.randint(0, 9)
#选择该类别的一个图像
random_index_2 = random.randint(0, seld.grouped_examples[other_selected_class].shape[0] - 1)
index_2 = self.grouped_examples[other_selected_class][random_index_2]
image_2 = self.data[index_2].clone().float()
target = torch.tensor(0, dtype = torch.float) #标签为0 不相似
return image_1, image_2, target
训练和测试函数
def train(args, model, device, train_loader, optimizer, epoch):
训练函数
model.train() #设置为训练模式
criterion = nn.BCELoss() #二元交叉墒损失函数
for batch_idx, (images_1, images_2, targets) in enumerate(train_loader):
#将数据移动到设备 GPU
images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
#梯度清零
optimizer.zero_grad()
#前向传播
outputs = model(images_1, images_2).squeeze()
#计算损失
loss = criterion(outputs, targets)
#反向传播
loss.backward()
#参数更新
optimizer.step()
#日志输出
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(images_1), len(train_loader.dataset),
- 100. * batch_idx / len(train_loader), loss.item()))
if args.dry_run: #快速测试模式
break
def test(model, device, test_loader):
测试函数
model.eval() #设置为评估模式
test_loss = 0
correct = 0
criterion = nn.BCELoss() #二元交叉墒损失函数
while torch.no_grad(): 禁用梯度计算
for (images_1, images_2, targets) in test_loader:
#将数据移动到设备
images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
#前向传播
outputs = model(images_1, images_2).squeeze()
#累加损失
test_loss += criterion(outputs, targets).sum().item()
#计算准确率,预测值 > 0.5 视为正样本
pred = torch.where(outputs > 0.5, 1, 0)
correct += pred.eq(targets.view_as(pred)).sum().item()
#计算平均损失
test_loss /= len(test_loader.dataset)
#输出测试结果
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
- 100. * correct / len(test_loader.dataset)))
主函数
def main():
parser = argparse.ArgumentParser(description='PyTorch Siamese network Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='输入训练批次大小 (默认: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='输入测试批次大小 (默认: 1000)')
parser.add_argument('--epochs', type=int, default=14, metavar='N',
help='训练轮数 (默认: 14)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='学习率 (默认: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='学习率衰减系数 (默认: 0.7)')
parser.add_argument('--no-accel', action='store_true',
help='禁用加速器')
parser.add_argument('--dry-run', action='store_true', default=False,
help='快速测试单次训练')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='随机种子 (默认: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='日志输出间隔批次')
parser.add_argument('--save-model', action='store_true', default=False,
help='保存当前模型')
args = parser.parse_args()
#检查是否使用加速器
use_accel = not args.no_accel add torch.cuda.is_available()
#设置随机种子
torch.manual_seed(args.seed)
#设置计算设备
if use_accel:
device = torch.device("cuda")
else:
device = torch.device("cpu")
print(f"使用设备: {device}")
#配置数据加载器参数
train_kwargs = {'batch_szie': args.batch_size}
test_kwargs = {'batch_size':args.test_batch_size}
if use_accel:
accel_kwargs = {'num_workers': 1, #数据加载子进程数
‘pin_memory’: True, #锁页内存,加速数据传输
'shuffle': True} #训练数据洗牌
train_kwargs.update(accel_kwargs)
test_kwargs.update(accel_kwargs)
#创建数据集和数据加载器
train_dataset = APP_MATCHER('../data', train = True, download=True)
test_dataset = APP_MATCHER('../data', tarin=False)
train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)
#初始化模型,优化器和学习率调度器
model = SiameseNetwork().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr) #Adadelta优化器
scheduler = StepLR(optimizer, step_size=1, gamma=args.magga) #雪梨洗衰减
#训练循环
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
scheduler.step() #更新学习率
#保存模型
if args.save_model:
torch.save(model.state_dict(), "siamese_network.pt")
#程序入口
if __name__=='__main__':
main()
关键概念解释
1 SIgmese网络
一种特殊的神经网络结构,包含两个或者多个相同的子网络
子网络共享权重参数
用于学习输入之间的相似度关系
2 MNIST数据集
手写数字数据集
包含60000张训练图像和10000张测试图像
每张图像大小为28x28像素
3 ResNet-18
深度残差网络,包含18层
通过残差连接解决深层网络梯度消失问题
这里用作特征提取器
4 训练过程
生成图像对(相同类别或者不同类别)
计算相似度预测值
使用二元交叉墒损失优化网络参数
5 应用场景
人脸验证
签名验证
产品相似度匹配
异常检测