当前位置: 首页 > news >正文

【深度学习基础】:VGG实战篇(图像风格迁移)

文章目录

  • 前言
  • style transfer原理
    • 原理解析
    • 损失函数
  • style transfer代码
    • 效果图
  • fast style transfer 代码
    • 效果图

前言

本篇来带大家看看VGG的实战篇,这次来带大家看看计算机视觉中一个有趣的小任务,图像风格迁移。

可运行代码位于: Style_transfer (可直接下载运行)

我们所进行的图像风格迁移是基于VGG实现的哦,如果对VGG网络还不太了解的,可以先去看看这边文章:【深度学习基础】:VGG原理篇-CSDN博客

image-20250502122534540

style transfer原理

原理解析

图像风格迁移是一种通过将一张图像的风格应用到另一张图像的内容上,从而生成一张新图像的技术。这种技术通常基于深度学习模型,特别是卷积神经网络(CNN)。

实现步骤

  • 首先,我们需要两张图像:内容图像和风格图像。内容图像是我们想要保持内容的图像,而风格图像是我们想要应用到内容图像上的风格的图像。接下来,我们通过使用预训练的卷积神经网络(如VGG19)来提取图像的特征。
  • 在提取特征之后,我们计算内容损失和风格损失。内容损失度量生成图像与内容图像在内容上的差异,通常通过比较两个图像在浅层特征上的差异来计算。风格损失度量生成图像与风格图像在风格上的差异,风格通常通过计算图像特征之间的相关性来表示,即风格图像的特征图之间的gram矩阵。(这个等会会详细讲述)

其实图像风格迁移的核心是一个优化问题,目标是生成一张图像,使得内容损失和风格损失都最小化。我们初始化一张随机生成的图像。然后,我们使用卷积神经网络提取这张图像的特征,并计算内容损失和风格损失。接着,我们计算总损失,并进行反向传播和优化,以更新生成图像的像素值。这个过程会重复进行多次,直到生成图像在内容和风格上都接近目标图像。

损失函数

好,其实大概的原理基本上都清楚了,但是我们来看,我们是该如何去保证图片的内容是我们所选的内容图片,但是风格确实我们所选的风格图片呢?

首先来看内容函数的损失函数,其是这个还是很好理解的,我们所使用的就是平方差损失函数,我们希望我们的生成图片能够和内容图片具有相同的内容,所以最小化两者的差距即可。但是这里有个细节,我们该选择那个层的输出作为我们的标准呢?其实看最上面的图也能看出,我们的网络越深,提取到的内容图片损失的细节就会越多,所以我们会更多的选择浅层的内容图片特征作为标准。

image-20250502125827494

def calculate_content_loss(original_feat, generated_feat) -> torch.Tensor:"""计算内容损失,即生成特征图与标准特征图的规范化误差平方和"""b, c, h, w = original_feat.shapex = 2. * c * h * w  # 规范化系数return torch.sum((generated_feat - original_feat) ** 2) / x

然后我们在来看风格损失函数,其实风格损失函数的关键在于gram矩阵。那么我们回到最初,我们为什么能够提取出来图像的风格特征呢?其实就是依靠Gram矩阵了,其用于捕捉图像中不同通道之间的相关性,这些相关性可以反映出图像的纹理、颜色分布等风格特征。通俗说就是出现尖尖的形状的时候边上应该出现些什么。

image-20250502125816341

然后风格损失函数如下,同样的我们会选择深层的风格图像特征作为我们的标准,因为深层网络提取出来的更多的是抽象的语义信息,能够代表其本质的特征,抽象的特征。

深度学习项目二: 图像的风格迁移和图像的快速风格迁移 (含数据和所需源码)_牛客博客

def gram_matrix(feature):"""计算特征图的格莱姆矩阵,作为风格特征表示:param feature: 输入特征图:return: 格莱姆矩阵"""b, c, h, w = feature.size()feature = feature.view(b, c, h * w)  # 拉平空间维度G = torch.bmm(feature, feature.transpose(1, 2))return Gdef calculate_style_loss(style_feat, generated_feat) -> torch.Tensor:"""计算风格损失,即生成特征图与标准特征图的格拉姆矩阵的规范化误差平方和"""b, c, h, w = style_feat.shapeG = gram_matrix(generated_feat)A = gram_matrix(style_feat)x = 4. * ((h * w) ** 2) * (c ** 2)  # 规范化系数return torch.sum((G - A) ** 2) / x

所以最终的损失函数就是风格损失和内容损失的和,其中会乘上不同的系数。

image-20250502125754226

style transfer代码

这里只展现了部分的核心代码,所有代码位于Style_transfer ,下载可直接使用。

我们通过处理目标图片,其一开始就是一张白噪声图,然后通过网络损失不断修改其像素,使得目标图片内容与内容图片相近,风格与风格图片相近。

image-20250502133002891

import argparse
import os
import torch
import torch.optim as optim
from tqdm import tqdmfrom models import VGG
from utils import make_transform, load_image, save_image, calculate_style_loss, calculate_content_lossdef parse_args():parser = argparse.ArgumentParser(description='Style Transfer=')parser.add_argument('--content_image', type=str, default='./data/content1.jpg', help='Path to the content image')parser.add_argument('--style_image', type=str, default='./data/style2.jpg', help='Path to the style image')parser.add_argument('--output_dir', type=str, default='./output/iterative_style_transfer', help='Output directory')parser.add_argument('--image_size', type=int, nargs=2, default=[300, 450], help='Image size (height, width)')parser.add_argument('--content_weight', type=float, default=1, help='Content weight')parser.add_argument('--style_weight', type=float, default=15, help='Style weight')parser.add_argument('--epochs', type=int, default=20, help='Number of training epochs')parser.add_argument('--steps_per_epoch', type=int, default=100, help='Number of steps per epoch')parser.add_argument('--learning_rate', type=float, default=0.03, help='Learning rate')return parser.parse_args()if __name__ == "__main__":args = parse_args()# 检查文件路径assert os.path.exists(args.content_image), f"content image is not exist: {args.content_image}"assert os.path.exists(args.style_image), f"style image is not exist: {args.style_image}"# 内容特征层及loss加权系数content_layers = {'5': 0.5, '10': 0.5}# 风格特征层及loss加权系数style_layers = {'0': 0.2, '5': 0.2, '10': 0.2, '19': 0.2, '28': 0.2}# ----------------训练即推理过程----------------device = torch.device("cuda" if torch.cuda.is_available() else "cpu")transform = make_transform(args.image_size, normalize=True)content_img = load_image(args.content_image, transform).to(device)style_img = load_image(args.style_image, transform).to(device)# 用小随机噪声初始化图像(避免标准正态分布过大震荡)generated_img = torch.rand_like(content_img).mul(0.1).requires_grad_().to(device)save_image(generated_img, args.output_dir, 'noise_init.jpg')vgg_model = VGG(content_layers, style_layers).to(device).eval()# 关闭梯度追踪,节省显存with torch.no_grad():content_features, _ = vgg_model(content_img)_, style_features = vgg_model(style_img)optimizer = optim.Adam([generated_img], lr=args.learning_rate)for epoch in range(args.epochs):p_bar = tqdm(range(args.steps_per_epoch), desc=f'Epoch {epoch + 1}/{args.epochs}')for step in p_bar:generated_content, generated_style = vgg_model(generated_img)content_loss = sum(args.content_weight * content_layers[name] * calculate_content_loss(content_features[name], gen_content)for name, gen_content in generated_content.items())style_loss = sum(args.style_weight * style_layers[name] * calculate_style_loss(style_features[name], gen_style)for name, gen_style in generated_style.items())total_loss = style_loss + content_lossoptimizer.zero_grad()total_loss.backward()# 梯度裁剪(可选但稳健)torch.nn.utils.clip_grad_norm_([generated_img], max_norm=1.0)optimizer.step()p_bar.set_postfix(style_loss=style_loss.item(), content_loss=content_loss.item(), total_loss=total_loss.item())# 保存中间结果save_image(generated_img, args.output_dir, f'generated_epoch_{epoch + 1}.jpg')

效果图

89_8

fast style transfer 代码

刚才我们看了基于迭代的图像风格迁移,那么能不能有个方法,我们训练一个专用于一种风格的图像迁移的网络,其网络训练好之后,我们就不再需要老是输入风格图像,这个训练好的网络可以处理所有输入的图像输出其已经训练好的风格。所以fast style transfer 来了。其不仅可以处理图片,同样可以处理视频哦!

image-20250502132529144

这里只展现了部分的核心代码,所有代码位于Style_transfer ,下载可直接使用。

import argparse
import os
from datetime import datetime, timedelta
from typing import List, Iterablefrom PIL import Imageimport torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import cv2
import numpy as npfrom models import VGG, TransNet
from datasets import ImageDataset
from utils import load_image, save_image, make_transform, save_model, calculate_style_loss, calculate_content_loss, \denormalizedef parse_args():parser = argparse.ArgumentParser(description='Fast Style Transfer')parser.add_argument('--mode', type=str, choices=['train', 'image', 'video'], default='train',help='Operation mode: train, image, video')parser.add_argument('--image_size', type=int, nargs=2, default=[300, 450], help='Image size (height, width)')parser.add_argument('--style_image', type=str, default='./data/style3.jpg', help='Style image path (training mode only)')# The directory data/train2017/default_class contains several images. Due to the format required by ImageFolder, we need to use a class folder to contain the images, even though the class label is not used.parser.add_argument('--content_dataset', type=str, default='data/train2014', help='Content image dataset path (training mode only)')parser.add_argument('--content_weight', type=float, default=1., help='Content weight (training mode only)')parser.add_argument('--style_weight', type=float, default=15., help='Style weight (training mode only)')parser.add_argument('--model_save_path', type=str, default=f'./checkpoint/style2.pth', help='Model save path (training mode)')parser.add_argument('--pretrained_model_path', type=str, help='Pretrained model load path (training mode only)')parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (training mode)')parser.add_argument('--save_interval', type=int, default=120, help='Save interval (seconds) (training mode)')parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (training mode)')parser.add_argument('--output_dir', type=str, default='./output/realtime_transfer', help='Output directory (training mode)')parser.add_argument('--model_path', type=str, default='./checkpoint/style1.pth', help='Model load path (image and video mode)')parser.add_argument('--input_images_dir', type=str, help='Input images root directory (image mode only)')parser.add_argument('--output_images_dir', type=str, default='./output/fast_style_transfer/image_generated.jpg',help='Output images root directory (image mode only)')parser.add_argument('--video_input', type=str, default='data/maigua.mp4', help='Input video path (video mode only)')parser.add_argument('--video_output', type=str, default='output/videos/maigua.mp4',help='Output video path (video mode only)')parser.add_argument('--batch_size', type=int, default=4, help='Batch size')return parser.parse_args()def train(model,vgg,lr, epochs, batch_size, style_weight, content_weight, style_layers, content_layers,device, transform,image_style,content_dataset_root,save_path, output_dir,save_interval=timedelta(seconds=120)):optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # 对目标网络进行优化dataset = ImageDataset(content_dataset_root, transform=transform)dataloader = DataLoader(dataset, batch_size, shuffle=True)_, style_features = vgg(image_style)p_bar = tqdm(range(epochs))last_save_time = datetime.now() - save_intervalfor epoch in p_bar:running_content_loss, running_style_loss = 0.0, 0.0for i, content_img in enumerate(dataloader):content_img = content_img.to(device)image_generated = model(content_img)  # 只使用内容图像进行风格迁移generated_content, generated_style = vgg(image_generated)style_loss = sum(style_weight * style_layers[name] * calculate_style_loss(style_features[name], gen_style) forname, gen_style in generated_style.items())content_features, _ = vgg(content_img)  # 计算内容图的内容特征content_loss = sum(content_weight * content_layers[name] * calculate_content_loss(content_features[name], gen_content) forname, gen_content in generated_content.items())total_loss = style_loss + content_lossoptimizer.zero_grad()total_loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)  # 梯度裁剪optimizer.step()running_content_loss += content_loss.item()running_style_loss += style_loss.item()p_bar.set_postfix(progress=f'{(i + 1) / len(dataloader) * 100:.3f}%',style_loss=f"{style_loss.item():.3f}",content_loss=f"{content_loss.item():.3f}",last_save_time=last_save_time)if datetime.now() - last_save_time > save_interval:last_save_time = datetime.now()#writer.add_images('image_generated', denormalize(image_generated), epoch * len(dataloader) + i)save_model(model, save_path)  # 'fast_style_transfer.pth'save_image(torch.cat((image_generated, content_img), 3), output_dir, f'{epoch}_{i}.jpg')def process_images(images: Iterable[Image.Image], transform, model, device) -> List[Image.Image]:images = torch.stack([transform(image) for image in images]).to(device)model.to(device)batch_generated = model(images)batch_generated = denormalize(batch_generated).detach().cpu()batch_generated = [transforms.ToPILImage()(image) for image in batch_generated]return batch_generateddef process_video(video_path, output_path, transform, model, device, batch_size=4):# 打开视频文件cap = cv2.VideoCapture(video_path)output_dir, filename = os.path.split(output_path)if output_dir and not os.path.exists(output_dir):os.makedirs(output_dir)# 获取视频属性frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))fps = cap.get(cv2.CAP_PROP_FPS)total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))# 定义视频编码器和创建 VideoWriter 对象fourcc = cv2.VideoWriter_fourcc(*'mp4v')out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))# 初始化 tqdm 进度条pbar = tqdm(total=total_frames, desc="Processing Video")# 读取视频并批量处理frames = []while True:ret, frame = cap.read()if not ret:breakframes.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))if len(frames) == batch_size:batch_generated = process_images(frames, transform, model, device)for gen_frame in batch_generated:gen = cv2.cvtColor(np.array(gen_frame), cv2.COLOR_RGB2BGR)cv2.imshow("output", gen)gen = cv2.resize(gen, (frame_width, frame_height))out.write(gen)if cv2.waitKey(1) & 0xFF == ord('s'):breakframes.clear()pbar.update(batch_size)if frames:batch_generated = process_images(frames, transform, model, device)for gen_frame in batch_generated:out.write(cv2.cvtColor(np.array(gen_frame), cv2.COLOR_RGB2BGR))pbar.update(len(frames))print(f'video successfully saved to: {output_path}')pbar.close()cap.release()out.release()if __name__ == '__main__':args = parse_args()# print(args)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# ----------------路径参数----------------# 内容特征层及loss加权系数content_layers = {'5': 0.5, '10': 0.5}  # 使用vgg的较浅层特征作为内容特征,保证生成图片内容结构相似性# 风格特征层及loss加权系数style_layers = {'0': 0.2, '5': 0.2, '10': 0.2, '19': 0.2, '28': 0.2}  # 使用vgg不同深度的风格特征,生成风格更加层次丰富transform = make_transform(size=args.image_size, normalize=True)  # 图像变换image_style = load_image(args.style_image, transform=transform).to(device)  # 风格图像vgg = VGG(content_layers, style_layers).to(device)  # 特征提取网络,只用来提取特征,不进行训练model = TransNet(input_size=args.image_size).to(device)  # 内容生成网络,用于生成风格图片,进行训练if args.mode != 'train' and getattr(args, 'model_path'):if not os.path.exists(args.model_path):raise FileNotFoundError(f'{args.model_path}不存在!')model.load_state_dict(torch.load(args.model_path))elif args.mode == 'train' and getattr(args, 'pretrained_model_path') and os.path.exists(args.pretrained_model_path):if not os.path.exists(args.pretrained_model_path):raise FileNotFoundError(f'{args.pretrained_model_path}不存在!')model.load_state_dict(torch.load(args.pretrained_model_path))if args.mode == 'train':# 训练模式# 使用大规模内容图像数据训练快速图像风格迁移网络,比如COCO2017数据集train(model, vgg, args.learning_rate, args.epochs, args.batch_size, args.style_weight, args.content_weight,style_layers,content_layers, device, transform, image_style, args.content_dataset, args.model_save_path,args.output_dir,save_interval=timedelta(seconds=args.save_interval))elif args.mode == 'image':# 使用训练好的风格迁移模型演示批量处理图片if not os.path.exists(args.output_images_dir):os.makedirs(args.output_images_dir)for filename in tqdm(os.listdir(args.input_images_dir), desc='Processing Images'):try:filepath = os.path.join(args.input_images_dir, filename)images_generated = process_images([Image.open(filepath)], transform, model, device)images_generated[0].save(os.path.join(args.output_images_dir, filename))except Exception as e:passelif args.mode == 'video':# 视频处理模式process_video(args.video_input, args.video_output, transform, model, device,batch_size=args.batch_size)else:raise ValueError("未知的运行模式")

效果图

tutieshi_640x480_3s

http://www.xdnf.cn/news/268381.html

相关文章:

  • [Windows] Kazumi番剧采集v1.6.9:支持自定义规则+在线观看+弹幕,跨平台下载
  • ecs网站备份,ecs网站备份的方法
  • 基于YOLOv8的人流量识别分析系统
  • 普通 html 项目引入 tailwindcss
  • 【算法专题九】链表
  • Socket 编程 UDP
  • C++继承基础总结
  • GESP2024年6月认证C++八级( 第三部分编程题(2)空间跳跃)
  • VFS Global 携手 SAP 推动数字化转型
  • Three.js支持模型格式区别、建议
  • <property name=“userDao“ ref=“userDaoBean“/> 这两个的作用和语法
  • Java虚拟线程基础介绍
  • 23.合并k个升序序链表- 力扣(LeetCode)
  • Spring Cloud与Service Mesh集成:Istio服务网格实践
  • 【学习笔记】 强化学习:实用方法论
  • deepseek提供的Red Hat OpenShift Container Platform 4.X巡检手册
  • 深入理解Redis SDS:高性能字符串的终极设计指南
  • 基于Springboot高校网上缴费综合务系统【附源码】
  • CSS元素动画篇:基于当前位置的变换动画(合集篇)
  • 《算法导论(第4版)》阅读笔记:p2-p3
  • Java大师成长计划之第11天:Java Memory Model与Volatile关键字
  • 【Mytais系列】Myatis的设计模式
  • API接口:轻松获取企业联系方式
  • 理解Android Studio IDE工具
  • 虚幻基础:角色朝向
  • MIT6.S081-lab8前置
  • C++ 开发指针问题:E0158 表达式必须为左值或函数指示符
  • UDP 通信详解:`sendto` 和 `recvfrom` 的使用
  • python进阶(1)字符串
  • DeepSeek-Prover-V2-671B:AI在数学定理证明领域的重大突破