当前位置: 首页 > backend >正文

【课堂笔记】生成对抗网络 Generative Adversarial Network(GAN)

文章目录

  • 问题背景
  • 原理
  • 更新过程
    • 判别器
    • 生成器

问题背景

  一方面,许多机器学习任务需要大量标注数据,但真实数据可能稀缺或昂贵(如医学影像、稀有事件数据)。如何在少量数据中达到一个很好的训练效果是一个很重要的问题。
  另一方面,传统生成模型(如变分自编码器VAE)生成的样本往往模糊或缺乏多样性,难以捕捉真实数据的复杂分布(如高分辨率图像、复杂文本等)。
  生成式对抗网络(GAN)提出了用生成器(Generator)和判别器(Discriminator),通过对抗训练相互竞争来提高性能。这样能够生成与真实数据分布相似的合成数据,用于数据增强;同时通过生成器和判别器的对抗训练,生成器学习到真实数据的概率分布,生成的样本更加逼真、细节丰富。

原理

  GAN由两个神经网络组成:
(1)生成器 G \mathbf{G} G:输入随机噪声 z ∼ p G ( z ) z \sim p_G(z) zpG(z)(通常是正态或均匀分布),输出生成的假数据 G ( z ) \mathbf{G}(z) G(z),试图模仿真实数据分布 p data p_{\text{data}} pdata
(2)判别器 D \mathbf{D} D:输入数据(真实数据 x ∼ p data x \sim p_{\text{data}} xpdata或假数据 p data p_{\text{data}} pdata),输出概率 D ( x ) ∈ [ 0 , 1 ] \mathbf{D}(x) \in [0, 1] D(x)[0,1],表示数据为真实的概率。
  这两个神经网络是对抗性的,生成器 G \mathbf{G} G企图让假数据更逼真,来让 D \mathbf{D} D犯错;而判别器 D \mathbf{D} D试图最大化区分真假数据的准确性。

  基于这个目的,我们构造一个损失函数:
(1)对于真实数据 x ∼ p data x \sim p_{\text{data}} xpdata,我们希望 D ( x ) → 1 \mathbf{D}(x) \rightarrow 1 D(x)1,定义损失为 − log ⁡ D ( x ) -\log\mathbf{D}(x) logD(x)
(2)对于生成数据 G ( z ) ∼ p G \mathbf{G}(z) \sim p_G G(z)pG,我们希望 D ( G ( z ) ) → 0 \mathbf{D}(\mathbf{G}(z))\rightarrow 0 D(G(z))0,定义损失为 − log ⁡ ( 1 − D ( G ( z ) ) ) -\log(1-\mathbf{D}(\mathbf{G}(z))) log(1D(G(z)))
  判别器的目标是最大化正确分类的概率,即最小化以下损失:
L D = − E x ∼ p data [ log ⁡ D ( x ) ] − E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] L_D = - \mathbb{E}_{x \sim p_{\text{data}}} \left[ \log D(x) \right] - \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right] LD=Expdata[logD(x)]Ezpz[log(1D(G(z)))]
  生成器的目标是欺骗判别器,即最小化以下损失:
L G = E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] L_G = \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right] LG=Ezpz[log(1D(G(z)))]
  结合两者,我们可以写出GAN的整体目标函数:
min ⁡ G max ⁡ D ( E x ∼ p data [ log ⁡ D ( x ) ] + E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] ) \min_G \max_D \left(\mathbb{E}_{x \sim p_{\text{data}}} \left[ \log D(x) \right] + \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right]\right) GminDmax(Expdata[logD(x)]+Ezpz[log(1D(G(z)))])
  接下来去解决这个目标,为了叙述方便定义记号 V ( N , G ) V(N, G) V(N,G),并改写为积分形式:
V ( D , G ) : = E x ∼ p data [ log ⁡ D ( x ) ] + E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] = ∫ x p data ( x ) log ⁡ D ( x ) d x + ∫ x p g ( x ) log ⁡ ( 1 − D ( x ) ) d x = ∫ x f ( D ( x ) ) d x f ( D ( x ) ) : = p data ( x ) log ⁡ D ( x ) + p g ( x ) log ⁡ ( 1 − D ( x ) ) \begin{align*} V(D, G) &:= \mathbb{E}_{x \sim p_{\text{data}}} \left[ \log D(x) \right] + \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right] \\ &=\int_x p_{\text{data}}(x) \log D(x) \, dx + \int_x p_g(x) \log (1 - D(x)) \, dx \\ &=\int_x f(D(x))dx \\ f(D(x)) &:= p_{\text{data}}(x) \log D(x) + p_g(x) \log (1 - D(x)) \end{align*} V(D,G)f(D(x)):=Expdata[logD(x)]+Ezpz[log(1D(G(z)))]=xpdata(x)logD(x)dx+xpg(x)log(1D(x))dx=xf(D(x))dx:=pdata(x)logD(x)+pg(x)log(1D(x))
  首先我们要找最大化 V ( D , G ) V(D, G) V(D,G) D ∗ D^* D,于是对 D D D求导:
∂ f ∂ D ( x ) = p data ( x ) D ( x ) − p g ( x ) 1 − D ( x ) = 0 ⇒ D ∗ ( x ) = p data ( x ) p data ( x ) + p g ( x ) \frac{\partial f}{\partial D(x)} = \frac{p_{\text{data}}(x)}{D(x)} - \frac{p_g(x)}{1 - D(x)} = 0 \\ \Rightarrow D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)} D(x)f=D(x)pdata(x)1D(x)pg(x)=0D(x)=pdata(x)+pg(x)pdata(x)
  这个结果表面,最有判别器 D ∗ D^* D输出真实数据和生成数据分布的相对概率。
  接下来将 D ∗ D^* D代入:
V ( D ∗ , G ) = ∫ x [ p data ( x ) log ⁡ ( p data ( x ) p data ( x ) + p g ( x ) ) + p g ( x ) log ⁡ ( p g ( x ) p data ( x ) + p g ( x ) ) ] d x V(D^*, G) = \int_x \left[ p_{\text{data}}(x) \log \left( \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)} \right) + p_g(x) \log \left( \frac{p_g(x)}{p_{\text{data}}(x) + p_g(x)} \right) \right] dx V(D,G)=x[pdata(x)log(pdata(x)+pg(x)pdata(x))+pg(x)log(pdata(x)+pg(x)pg(x))]dx
  这个式子比较复杂,经过推导可以证明:
V ( D ∗ , G ) = − log ⁡ 4 + 2 ⋅ JS ( p data ∥ p g ) V(D^*, G) = - \log 4 + 2 \cdot \text{JS}(p_{\text{data}} \| p_g) V(D,G)=log4+2JS(pdatapg)
  其中 J S \mathbf{JS} JS是Jensen-Shannon 散度,它与 K L \mathbf{KL} KL散度的关系为:
JS ( p data ∥ p g ) = 1 2 KL ( p data ∥ p data + p g 2 ) + 1 2 KL ( p g ∥ p data + p g 2 ) \text{JS}(p_{\text{data}} \| p_g) = \frac{1}{2} \text{KL} \left( p_{\text{data}} \| \frac{p_{\text{data}} + p_g}{2} \right) + \frac{1}{2} \text{KL} \left( p_g \| \frac{p_{\text{data}} + p_g}{2} \right) JS(pdatapg)=21KL(pdata2pdata+pg)+21KL(pg2pdata+pg)
  这个结果是合理的。当 p g = p d a t a p_g = p_{data} pg=pdata时, J S \mathbf{JS} JS散度为0,此时目标函数达到最小值 − log ⁡ 4 -\log 4 log4 D ∗ ( x ) = 0.5 \mathbf{D}^*(x) = 0.5 D(x)=0.5,将无法区分数据的真假。
  对于生成器 G \mathbf{G} G的优化等价于最小化这个 J S \mathbf{JS} JS散度。

更新过程

  在上述推导中,对随机分布进行了期望积分,但实际操作过程中直接计算上述积分是不可行的,我们会采用蒙特卡洛方法近似期望值,于是下面的 L D L_D LD L G L_G LG是用约等于。
  蒙特卡洛方法:核心是利用随机性和大数定律,通过从分布 p ( x ) p(x) p(x)中采集大量样本点 x 1 , . . . , x n x_1, ..., x_n x1,...,xn,然后计算样本均值来近似期望值:
E [ f ( X ) ] ≈ 1 n ∑ i = 1 n f ( x i ) \mathbb{E}[f(X)] \approx \frac{1}{n} \sum_{i=1}^n f(x_i) E[f(X)]n1i=1nf(xi)

判别器

  在理论分析中,我们得到了最优判别器 D ∗ ( x ) = p data ( x ) p data ( x ) + p g ( x ) \mathbf{D}^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)} D(x)=pdata(x)+pg(x)pdata(x),然而我们不知道数据实际分布 p data p_{\text{data}} pdata,通常采用梯度下降等方式来拟合:
(1)从真实数据中采集一批 x 1 , . . . , x m x_1, ..., x_m x1,...,xm,从生成器中生成一批 G ( z 1 ) , . . . , G ( z m ) G(z_1), ..., G(z_m) G(z1),...,G(zm)
(2)使用梯度下降优化损失 L D L_D LD θ D \theta_D θD是神经网络 D \mathbf{D} D的参数:
L D ≈ − 1 m ∑ i = 1 m [ log ⁡ D ( x i ) + log ⁡ ( 1 − D ( G ( z i ) ) ) ] θ D ← θ D + η ⋅ ∇ θ D L D L_D \approx -\frac{1}{m} \sum_{i=1}^m \left[ \log D(x_i) + \log (1 - D(G(z_i))) \right] \\ \theta_D \gets \theta_D + \eta \cdot \nabla_{\theta_D} L_D LDm1i=1m[logD(xi)+log(1D(G(zi)))]θDθD+ηθDLD

生成器

  生成器的训练和判别器交替进行,同样采用梯度下降等方法来拟合:
(1)从生成器中生成一批 G ( z 1 ) , . . . , G ( z m ) G(z_1), ..., G(z_m) G(z1),...,G(zm)
(2)使用当前判别器 D \mathbf{D} D(已部分训练)计算生成器损失的近似:
L G ≈ − 1 m ∑ i = 1 m log ⁡ D ( G ( z i ) ) L_G \approx -\frac{1}{m} \sum_{i=1}^m \log D(G(z_i)) LGm1i=1mlogD(G(zi))
(3)计算梯度并更新参数:
∇ θ G L G ≈ − 1 m ∑ i = 1 m ∇ θ G log ⁡ D ( G ( z i ) ) θ G ← θ G − η ⋅ ∇ θ G L G \nabla_{\theta_G} L_G \approx -\frac{1}{m} \sum_{i=1}^m \nabla_{\theta_G} \log D(G(z_i)) \\ \theta_G \gets \theta_G - \eta \cdot \nabla_{\theta_G} L_G θGLGm1i=1mθGlogD(G(zi))θGθGηθGLG

http://www.xdnf.cn/news/10450.html

相关文章:

  • 任务23:创建天气信息大屏Django项目
  • 【BootLoader】之stm32F407实现bootloader相关问题
  • Python+MongoDb使用手册(精简)
  • python打卡day42
  • 学习日记-day20-6.1
  • 【AI论文】推理语言模型的强化学习熵机制
  • Cocos 打包 APK 兼容环境表(Android API Level 10~15)
  • 从线性代数到线性回归——机器学习视角
  • 获取 HTTP 请求从发送到接收响应所花费的总时间
  • 什么是缺页中断(缺页中断详解)
  • 基于微信小程序的垃圾分类系统
  • 西瓜书第十章——聚类
  • 思科设备网络实验
  • 鸿蒙OSUniApp集成WebAssembly实现高性能计算:从入门到实践#三方框架 #Uniapp
  • 开发指南120-表格(el-table)斑马纹
  • 无法运用pytorch环境、改环境路径、隔离环境
  • Python编程基础(二)| 列表简介
  • 【Redis】笔记|第4节|Redis数据安全性分析
  • 数据类型与推断:TypeScript 的基础
  • wordpress免费主题网站
  • ASP.NET Core SignalR 身份认证集成指南(Identity + JWT)
  • Spring Boot,注解,@ConfigurationProperties
  • 手拆STL
  • 【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 时间事件处理部分)
  • Selenium的底层原理
  • 鸿蒙OSUniApp声纹识别与语音验证:打造安全可靠的跨平台语音应用#三方框架 #Uniapp
  • 第14讲、Odoo 18 实现一个Markdown Widget模块
  • 网络攻防技术一:绪论
  • 如何编写GitLab-CI配置文件
  • 【Linux】Linux文件系统详解