当前位置: 首页 > web >正文

人脸图像生成(DCGAN)

- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/rnFa-IeY93EpjVu0yzzjkw) 中的学习记录博客**
- **🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)**

深度卷积对抗网络(Deep Convolutional Generative Adversarial Networks)是一种深度学习模型,由生成器(Generator)和判别器(Discriminator)两个神经网络组成。DCGAN 结合了卷积神经网络和生成对抗网络的思想,用于生成逼真的图像。

一. 理论基础

1.DCGAN原理

深度卷积对抗网络是生成对抗网络的一种模型改进,其将卷积运算的思想引入到生成式模型当中来做无监督的训练,利用卷积网络强大的特征提取能力来提高生成网络的学习效果。DCGAN模型有以下特点:

  • 判别器模型使用了卷积步长取代了空间池化,生成器模型中使用了反卷积操作扩大数据维度。
  • 除了生成器模型的输出层和判别器模型的输入层,在整个对抗网络的其他层上都使用了Batch Normalization, 原因是Batch Normalization 可以稳定学习,有助于优化初始化参数值不良而导致的训练问题。
  • 整个网络去除了全连接层,直接使用卷积层连接生成器和判别器的输入层以及输出层。
  • 在生成器的输出层使用Tanh激活函数以控制输出范围,而在其他层中均使用了ReLU激活函数;在判别器上使用了Leaky ReLU激活函数。

图中所示了一种常见的DCGAN结构。主要包含了一个生成网络G 和一个判别网络 D,生成网络G 负责生成图像,它接受一个随机的噪声z,通过该噪声生成图像,将生成的图像记为G(z),判别网络D 负责判断一张图是否为真实,它的输入是x,代表一张图像,输出D(x)表示x为真实图像的概率。

实际上判别网络D是对数据的来源进行一个判别:究竟这个数据是来自真是的数据分布Pd(x)判别为“1”,还是来自于一个生成网络G所产生的一个数据分布Pg(z)(判别为“0”)。所以在整个训练过程中,生成网络G的目标是生成可以以假乱真的图像G(z),当判别网络D无法区分,即D(G(z))=0.5时,便得到了一个生成网络G用来生产图像扩充数据集。

二.前期准备

1.导入第三方库

import torch,random,os
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from torch.autograd import VariablemanualSeed = 999
print("random seed:",manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True)

2.设置超参数

dataroot = "/content/drive/MyDrive/GAN_Dataset"
batch_size = 128 #训练过程中的批次大小
image_size = 64 #图像的尺寸(宽度和高度)
nz = 100 # z潜在的向量大小(生成器输入的尺寸)
ngf = 64 # 生成器中的特征图大小
ndf = 64 #判别器中的特征图大小
num_epochs = 50 #训练的总论数
lr = 0.0002 #学习率
beta1=0.5 #adam 优化器的beta1超参数

3.导入数据

dataset = dset.ImageFolder(root=dataroot,transform = transforms.Compose([transforms.Resize(image_size),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]))
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True,num_workers=5)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:24],padding=2,normalize=True).cpu(),(1,2,0)))

三.定义模型

1.初始化权重

def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:nn.init.normal_(m.weight.data,0.0,0.02)elif classname.find('BatchNorm')!=-1:nn.init.normal_(m.weight.data,1.0,0.02)nn.init.constant_(m.bias.data,0)

2.定义生成器

class Generator(nn.Module):def __init__(self):super(Generator,self).__init__()self.main = nn.Sequential(nn.ConvTranspose2d(nz,ngf*8,4,1,0,bias=False),nn.BatchNorm2d(ngf*8),nn.ReLU(True),#输出尺寸:(ngf*8)x4x4nn.ConvTranspose2d(ngf*8,ngf*4,4,2,1,bias=False),nn.BatchNorm2d(ngf*4),nn.ReLU(True),#输出尺寸:(ngf*4)x8x8nn.ConvTranspose2d(ngf*4,ngf*2,4,2,1,bias=False),nn.BatchNorm2d(ngf*2),nn.ReLU(True),#输出尺寸:(ngf*2)x16x16nn.ConvTranspose2d(ngf*2,ngf,4,2,1,bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),#输出尺寸:(ngf)x32x32nn.ConvTranspose2d(ngf,3,4,2,1,bias=False),nn.Tanh()#输出尺寸:3x64x64)def forward(self,input):return self.main(input)
#创建生成器
netG = Generator().to(device)
netG.apply(weights_init)
print(netG)

3.定义鉴别器

class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.main = nn.Sequential(nn.Conv2d(3,ndf,4,2,1,bias=False),nn.LeakyReLU(0.2,inplace=True),#输出尺寸:(ndf)x32x32nn.Conv2d(ndf,ndf*2,4,2,1,bias=False),nn.BatchNorm2d(ndf*2),nn.LeakyReLU(0.2,inplace=True),#输出尺寸:(ndf*2)x16x16nn.Conv2d(ndf*2,ndf*4,4,2,1,bias=False),nn.BatchNorm2d(ndf*4),nn.LeakyReLU(0.2,inplace=True),#输出尺寸:(ndf*4)x8x8nn.Conv2d(ndf*4,ndf*8,4,2,1,bias=False),nn.BatchNorm2d(ndf*8),nn.LeakyReLU(0.2,inplace=True),#输出尺寸:(ndf*8)x4x4nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self,input):return self.main(input)
#创建判别器模型
netD = Discriminator().to(device)
netD.apply(weights_init)
print(netD)

四:训练模型

1.定义训练参数

criterion = nn.BCELoss()
fixed_noise = torch.randn(64,nz,1,1,device = device)real_label =1.
fake_label =0.optimizerD = optim.Adam(netD.parameters(),lr=lr,betas=(beta1,0.999))
optimizerG = optim.Adam(netG.parameters(),lr=lr,betas=(beta1,0.999))

2.训练模型

下面的训练代码是一个典型的GAN训练循环。在训练过程中,首先更新判别器网络,然后更新生成器网络。在每个epoch的每个batch中,会进行以下操作:

  • 更新判别器网络:通过训练真实图像样本和生成图像样本,最大化判别器的损失。具体步骤如下:

    • 对于真实图像样本,计算判别器对真实图像样本的输出和真实标签之间的损失,然后进行反向传播计算梯度。
    • 对于生成的图像样本,计算判别器对生成图像样本的输出和假标签之间的损失,然后进行反向传播计算梯度。
    • 将真实图像样本的损失和生成图像样本的损失相加得到判别器的总损失,并更新判别器的参数。
  • 更新生成器网络:通过最大化生成器的损失,迫使生成器产生更逼真的图像样本。具体步骤如下:

    • 使用生成器生成一批假图像样本。
    • 将生成图像样本输入判别器,计算判别器对生成图像样本的输出和真实标签之间的损失,并进行反向传播计算生成器的梯度。
    • 更新生成器的参数。
  • 输出训练统计信息:每隔一定的步数,输出当前训练的epoch、batch以及判别器和生成器的损失值等信息。

  • 保存损失值:将生成器和判别器的损失值存储到相应的列表中,以便后续绘图和分析。

  • 检查生成器的性能:每隔一定的步数或者在训练结束时,通过将固定的噪声输入生成器,生成一批图像样本,并保存到img_list列表中。这样可以观察生成器在训练过程中生成的图像质量的变化。

  • 更新迭代次数:每完成一个batch的训练,将迭代次数iters加1。

总体来说,这段代码实现了GAN的训练过程,通过交替更新判别器和生成器的参数,目标是使生成器生成逼真的图像样本,同时判别器能够准确区分真实图像样本和生成图像样本。

img_list =[]
G_losses=[]
D_losses=[]
iters=0
print("start training")for epoch in range(num_epochs):for i,data in enumerate(dataloader,0):netD.zero_grad()real_cpu = data[0].to(device)b_size = real_cpu.size(0)label = torch.full((b_size,),real_label,dtype=torch.float,device=device)output = netD(real_cpu).view(-1)errD_real = criterion(output,label)errD_real.backward()D_x = output.mean().item()#使用生成图像样本训练noise = torch.randn(b_size,nz,1,1,device=device)fake = netG(noise)label.fill_(fake_label)output = netD(fake.detach()).view(-1)errD_fake = criterion(output,label)errD_fake.backward()D_G_z1 = output.mean().item()errD = errD_real + errD_fakeoptimizerD.step()#更新生成器网络netG.zero_grad()label.fill_(real_label)output = netD(fake).view(-1)errG = criterion(output,label)errG.backward()D_G_z2 = output.mean().item()optimizerG.step()if i % 400 == 0:print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'% (epoch, num_epochs, i, len(dataloader),errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))G_losses.append(errG.item())D_losses.append(errD.item())if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):with torch.no_grad():fake = netG(fixed_noise).detach().cpu()img_list.append(vutils.make_grid(fake, padding=2, normalize=True))iters += 1

3.可视化

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)HTML(ani.to_jshtml())

real_batch = next(iter(dataloader))
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

http://www.xdnf.cn/news/15277.html

相关文章:

  • Java线程进阶-并发编程
  • python的病例管理系统
  • halcon 求一个tuple的极值点
  • 性能狂飙 Gooxi 8卡5090服务器重新定义高密度算力
  • 深入剖析Spring Bean生命周期:从诞生到消亡的全过程
  • JavaSE——Object
  • Linux驱动基本概念(内核态、用户态、模块、加载、卸载、设备注册、字符设备)
  • DSSA(Domain-Specific Software Architecture)特定领域架构
  • 台球 PCOL:极致物理还原的网页斯诺克引擎(附源码深度解析)
  • Leaflet面试题及答案(21-40)
  • 2025年体育科学与健康大数据国际会议(ICSSHBD 2025)
  • OpenAI 将推 AI Agent 浏览器:挑战 Chrome,重塑上网方式
  • 异构Active DataGuard对于convert参数的错误理解
  • SpringCloud之Feign
  • 从「小公司人事」到「HRBP」:选对工具,比转岗更能解决成长焦虑
  • 十二、k8s工程化管理Helm
  • Linux自动化构建工具(一)
  • pdf拆分
  • 《打破预设的编码逻辑:Ruby元编程的动态方法艺术》
  • LVS负载均衡-DR模式配置
  • 进制转换原理与实现详解
  • 【unity编辑器开发与拓展EditorGUILayoyt和GUILayoyt】
  • RISC-V:开源芯浪潮下的技术突围与职业新赛道 (三)RISC-V架构深度解剖(下)
  • 【八股消消乐】浅尝Kafka性能优化
  • 【面板数据】省级泰尔指数及城乡收入差距测算(1990-2024年)
  • Vue集成MarkDown
  • 开源界迎来重磅核弹!月之暗面开源了自家最新模型 K2
  • UC浏览器PC版自2016年后未再更新不支持vue3
  • Git Submodule 介绍和使用指南
  • 服务器机柜与网络机柜各自的优势