当前位置: 首页 > ops >正文

VAE学习笔记

模型结构:

(m1,m2,m3)是数据经过encoder 得到的编码 

(σ1,σ2,σ3)是控制噪音干扰程度的编码,就是为随机噪音码(e1,e2,e3)分配权重

损失函数2:如果没有对σi 的限制 生成的图片会希望噪音对自身生成图片的干扰越小,于是分配给噪音的权重越小,这样只需要将(σ1,σ2,σ3)赋为接近负无穷大的值就好了,直观上也能看出来在σi=0处取最小

VAE原理:

首先VAE认为 所有数据都是由某个隐藏变量生成的 学会了这个隐藏变量的分布 就可以生成数据。

关键步骤:

Encoder:把输入数据压缩成隐藏变量的分布参数(均值和方差),直接输出固定值会导致生成能力变差 输出分布可以随机采样增加多样性。

重参数化技巧:解决直接采样不可导问题 改用以下方式 。

                                z = μ + σ * ε, 其中 ε ~ N(0, 1)

Decoder:把隐藏变量 z 还原成数据(如生成新图片)。

损失函数:

        重构损失以及KL散度,KL散度主要是限制σ不要跑偏,保证生成多样性。

基础代码实现:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torch.nn.functional as F
from torchvision.utils import save_imageclass VAE(nn.Module):def __init__(self, input_size, latent_size):super(VAE, self).__init__()#编码器层self.fc1 = nn.Linear(input_size, 512)self.fc2 = nn.Linear(512, latent_size)self.fc3 = nn.Linear(512, latent_size)#解码器层self.fc4 = nn.Linear(latent_size, 512)self.fc5 = nn.Linear(512, input_size)def encode(self, x):x = F.relu(self.fc1(x)) #编码器的隐藏表示mu = self.fc2(x)logvar = self.fc3(x)return mu, logvardef reparameterize(self, mu, logvar):std = torch.exp(0.5*logvar)eps = torch.randn_like(std)return mu + eps*stddef decode(self, z):z = F.relu(self.fc4(z)) #将潜在变量Z解码为重构图像return torch.sigmoid(self.fc5(z)) #将隐藏表示映射回输入图像大小 用sigmoid激活 产生重构图像def forward(self, x):mu, logvar = self.encode(x)z = self.reparameterize(mu, logvar)out = self.decode(z)return out , mu, logvardef loss_function(recon_x, x, mu, logvar):MSE = F.mse_loss(recon_x, x.view(-1,input_size), reduction='sum')KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())return MSE + KLDif __name__ == '__main__':batch_size = 64epochs = 50sample_interval = 10learning_rate = 1e-3input_size = 784latent_size = 256device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])train_dateset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)train_loader = torch.utils.data.DataLoader(train_dateset, batch_size=batch_size, shuffle=True)model = VAE(input_size, latent_size).to(device)optimizer = optim.Adam(model.parameters(), lr=learning_rate)for epoch in range(epochs):model.train()train_loss = 0for batch_idx, (data, target) in enumerate(train_loader):data = data.to(device)data = data.view(-1,input_size)predict ,mu, logvar = model(data)loss = loss_function(predict, data, mu, logvar)train_loss += loss.item()optimizer.zero_grad()loss.backward()optimizer.step()train_loss =train_loss / len(train_loader)print('Epoch [{}/{}], Loss: {:.2f}]'.format(epoch + 1, epochs, train_loss))if (epoch+1) % sample_interval == 0:torch.save(model.state_dict(), f'./VAE{epoch+1}.pth')model.eval()with torch.no_grad():pic_num=10sample = torch.randn(pic_num, latent_size).to(device)sample_img = model.decode(sample)save_image(sample_img.view(pic_num,1,28,28), './sample'+str(pic_num)+'.png' , nrow = int(pic_num/2))

http://www.xdnf.cn/news/16986.html

相关文章:

  • Visual Studio Code的下载,安装
  • 机器学习(11):岭回归Ridge
  • iOS混淆工具有哪些?功能测试与质量保障兼顾的混淆策略
  • OpenLayers 入门指南【五】:Map 容器
  • C语言的数组与字符串
  • 力扣热题100——双指针
  • Hadoop MapReduce 3.3.4 讲解~
  • SpringBoot自动装配原理
  • 36.【.NET8 实战--孢子记账--从单体到微服务--转向微服务】--缓存Token
  • 编程算法:技术创新与业务增长的核心驱动力
  • IDA9.1使用技巧(安装、中文字符串显示、IDA MCP服务器详细部署和MCP API函数修改开发经验)
  • 电商直播流量爆发式增长,华为云分布式流量治理与算力调度服务的应用场景剖析
  • 构建属于自己的第一个 MCP 服务器:初学者教程
  • 从零认识OpenFlow
  • 学习游戏制作记录(角色属性和状态脚本)8.4
  • 【Linux指南】软件安装全解析:从源码到包管理器的进阶之路
  • AI鉴伪技术鉴赏:“看不见”的伪造痕迹如何被AI识破
  • Java项目:基于SSM框架实现的电子病历管理系统【ssm+B/S架构+源码+数据库+毕业论文+远程部署】
  • Git如何同步本地与远程仓库并解决冲突
  • 【iOS】渲染原理离屏渲染
  • 打造个人数字图书馆:LeaNote+cpolar如何成为你的私有化知识中枢?
  • 时序数据库如何高效处理海量数据
  • Spring P1 | 创建你的第一个Spring MVC项目(IDEA图文详解版,社区版专业版都有~)
  • 【数据库】使用Sql Server创建索引优化查询速度,一般2万多数据后,通过非索引时间字段排序查询出现超时情况
  • Anthropic 禁止 OpenAI 访问 Claude API:商业竞争与行业规范的冲突
  • 接口重试方案,使用网络工具的内置重试机制,并发框架异步重试,Spring Retry,消息队列重试,Feign调用重试,监控与报警,避坑指南
  • Linux 系统启动原理
  • mac 技巧
  • Postman 四种请求体格式全解析:区别、用法及 Spring Boot 接收指南
  • 手搓TCP服务器实现基础IO