当前位置: 首页 > java >正文

DIT(Diffusion In Transformer)学习笔记

DIT(Diffusion In Transformer)学习笔记


一、概率建模与数学推导

1. 扩散过程的条件概率重参数化

传统扩散模型的条件概率
传统扩散模型(如DDPM)的逆过程定义为:
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ t ) p_\theta(x_{t-1}|x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t,t), \Sigma_t) pθ(xt1xt)=N(xt1;μθ(xt,t),Σt)
其中均值 μ θ \mu_\theta μθ通过U-Net建模,仅依赖 x t x_t xt和标量时间步 t t t

DIT的条件概率重构
DIT引入Transformer建模时空依赖:
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; G θ ( Tokenize ( x t ) + E pos , E time ( t ) ) , σ t 2 I ) p_\theta(x_{t-1}|x_t) = \mathcal{N}\left(x_{t-1}; G_\theta\left( \text{Tokenize}(x_t) + E_{\text{pos}}, E_{\text{time}}(t) \right), \sigma_t^2 I \right) pθ(xt1xt)=N(xt1;Gθ(Tokenize(xt)+Epos,Etime(t)),σt2I)

  • Tokenize:将图像分割为 N × N N \times N N×N的patch(如16×16),生成序列长度 L = ( H / N ) × ( W / N ) L = (H/N) \times (W/N) L=(H/N)×(W/N)的Token序列 Tokenize ( x t ) ∈ R L × d \text{Tokenize}(x_t) \in \mathbb{R}^{L \times d} Tokenize(xt)RL×d
  • E pos E_{\text{pos}} Epos:Patch位置编码矩阵,采用绝对位置编码显式编码每个Patch的空间坐标 ( i , j ) (i,j) (i,j)
  • E time ( t ) E_{\text{time}}(t) Etime(t):时间步嵌入向量,通过频域调制生成:
    E time ( t ) = ∑ k = 0 d / 2 − 1 [ sin ⁡ ( 1 0 4 k / d t ) , cos ⁡ ( 1 0 4 k / d t ) ] E_{\text{time}}(t) = \sum_{k=0}^{d/2-1} \left[ \sin(10^{4k/d} t), \cos(10^{4k/d} t) \right] Etime(t)=k=0d/21[sin(104k/dt),cos(104k/dt)]
    实现不同频率分量的时间信息编码,避免梯度消失。

2. 扩散过程的时空联合建模

DIT将空间(Patch序列)与时间(扩散步)通过Transformer统一建模:

  • 输入处理:图像分割为Patch后,与位置编码 E pos E_{\text{pos}} Epos和时间嵌入 E time ( t ) E_{\text{time}}(t) Etime(t)相加,作为Transformer输入。
  • 条件概率输出:Transformer输出各Patch的均值 μ θ ∈ R L × d \mu_\theta \in \mathbb{R}^{L \times d} μθRL×d,方差保持各Patch独立的高斯分布 σ t 2 I \sigma_t^2 I σt2I,保留空间结构与时间依赖的联合建模。

3. 损失函数的深层设计逻辑

1. Patch级噪声预测损失( L patch \mathcal{L}_{\text{patch}} Lpatch
L patch = E t , i , j ∥ ϵ i , j − MLP ( Attn ( Q i , j , K , V ) ) ∥ 2 \mathcal{L}_{\text{patch}} = \mathbb{E}_{t,i,j} \left\| \epsilon_{i,j} - \text{MLP}(\text{Attn}(Q_{i,j}, K, V)) \right\|^2 Lpatch=Et,i,jϵi,jMLP(Attn(Qi,j,K,V))2

  • 将噪声分解为Patch级 ϵ i , j \epsilon_{i,j} ϵi,j,通过自注意力机制捕捉全局依赖后,MLP预测噪声并计算均方误差,聚焦局部细节与全局结构的联合优化。

2. 序列相关性约束损失( L seq \mathcal{L}_{\text{seq}} Lseq
L seq = E t [ 1 N 2 ∑ i , j KL ( p θ ( z i , j ∣ z < i , j ) ∥ q ( z i , j ∣ z < i , j ) ) ] \mathcal{L}_{\text{seq}} = \mathbb{E}_{t} \left[ \frac{1}{N^2} \sum_{i,j} \text{KL}(p_\theta(z_{i,j}|z_{<i,j}) \| q(z_{i,j}|z_{<i,j})) \right] Lseq=Et[N21i,jKL(pθ(zi,jz<i,j)q(zi,jz<i,j))]

  • 引入自回归先验 q ( z i , j ∣ z < i , j ) q(z_{i,j}|z_{<i,j}) q(zi,jz<i,j)(假设Patch按行优先生成),通过KL散度约束模型生成的条件分布,确保空间结构的逻辑一致性,避免生成不连贯问题。

二、自注意力机制的扩散适应性改进

1. 传统自注意力与扩散感知改进

传统自注意力
Attention ( Q , K , V ) = Softmax ( Q K T d ) V \text{Attention}(Q,K,V) = \text{Softmax}\left( \frac{QK^T}{\sqrt{d}} \right)V Attention(Q,K,V)=Softmax(d QKT)V

DIT的扩散感知注意力

(1)时间依赖的温度系数

Temp ( t ) = 1 β t d \text{Temp}(t) = \frac{1}{\sqrt{\beta_t d}} Temp(t)=βtd 1

  • 扩散初期( t → T t \to T tT β t \beta_t βt大)噪声主导,降低温度使注意力分布更尖锐,增强全局关联;后期( t → 0 t \to 0 t0 β t \beta_t βt小)信号主导,升高温度使注意力更平滑,聚焦局部细节。
(2)噪声掩码机制

M i , j = Sigmoid ( MLP ( E time ( t ) ) ) ⋅ I ∣ i − j ∣ < k ( t ) , k ( t ) = ⌈ α t N ⌉ M_{i,j} = \text{Sigmoid}\left( \text{MLP}(E_{\text{time}}(t)) \right) \cdot \mathbb{I}_{|i-j| < k(t)}, \quad k(t) = \lceil \alpha_t N \rceil Mi,j=Sigmoid(MLP(Etime(t)))Iij<k(t),k(t)=αtN

  • 动态控制感受野: α t = ∏ s = 1 t ( 1 − β s ) \alpha_t = \sqrt{\prod_{s=1}^t (1-\beta_s)} αt=s=1t(1βs) 随时间递减, k ( t ) k(t) k(t)从全局(早期大 k ( t ) k(t) k(t))过渡到局部(后期小 k ( t ) k(t) k(t)),减少冗余计算并保留多尺度依赖。
  • 掩码应用于注意力矩阵,实现软掩码与距离掩码的结合:
    Attn ( Q , K , V , t ) = Softmax ( Q K T ⊙ M d ⋅ Temp ( t ) ) V \text{Attn}(Q,K,V,t) = \text{Softmax}\left( \frac{QK^T \odot M}{\sqrt{d} \cdot \text{Temp}(t)} \right)V Attn(Q,K,V,t)=Softmax(d Temp(t)QKTM)V

三、与传统扩散模型的对比分析

1. 架构差异对比

维度传统扩散模型(如DDPM)DIT
主干网络U-Net(卷积结构)Transformer(自注意力结构)
条件建模方式时间步 t t t拼接/添加到各层时间嵌入与位置编码共同参与注意力计算
特征交互范围局部感受野(受卷积核限制)全局交互(自注意力机制)
位置信息处理无显式编码显式Patch位置编码
参数量级通常较小(约100M参数)较大(可扩展至10B参数)

2. 理论特性对比

特性DDPM/DDIMDIT
马尔可夫性假设严格马尔可夫链可支持非马尔可夫过程
生成过程可逆性单步不可逆通过自注意力保留路径依赖信息
损失函数形式基于像素级噪声预测联合优化Patch级和序列级损失
收敛速度较慢(需1000+步采样)快速(100-200步达到同等质量)

四、实际操作与工程实现

1. 模型架构配置建议

组件常规配置可调参数说明
Patch尺寸16×16(256×256图像)小尺寸(8×8)保留细节,大尺寸(32×32)降低计算量
Transformer层数12-24层(Base版)/ 36层(Large版)深层数需搭配LayerNorm和残差连接
注意力头数16头(d=1024)头数与维度匹配,避免维度碎片化
位置编码绝对位置编码+可学习参数支持相对位置编码(需调整注意力计算)

2. 训练优化策略

  • 数据预处理:图像归一化至 [ − 1 , 1 ] [-1, 1] [1,1],随机水平翻转;时间步 t t t均匀采样于 [ 1 , T ] [1, T] [1,T](T通常1000,DIT支持更少步数)。
  • 优化器:AdamW( β 1 = 0.9 , β 2 = 0.999 \beta_1=0.9, \beta_2=0.999 β1=0.9,β2=0.999,权重衰减0.05),学习率余弦衰减(初始 1 e − 4 1e-4 1e4,热身5000步)。
  • 混合精度:使用FP16混合精度(PyTorch AMP),减少显存占用,支持更大Batch Size(如128)。

3. 推理加速技术

  • 动态跳步采样:基于非马尔可夫假设,跳过部分时间步(如从1000步降至100-200步),优先在高 β t \beta_t βt阶段大步长跳跃,后期小步长细化细节。
  • 并行化Patch生成:Transformer支持所有Patch并行预测,生成速度随序列长度线性增长,显著提升高分辨率图像(如1024×1024)生成效率。

五、关键技术对比与适用场景

1. 与U-Net扩散模型的核心差异

特性U-Net(DDPM)Transformer(DIT)
空间建模卷积归纳偏置(强)自注意力(弱归纳偏置)
长程依赖依赖跳跃连接直接全局交互(复杂度 O ( L 2 ) O(L^2) O(L2)
分辨率扩展性受限于下采样/上采样层级支持任意Patch尺寸(位置编码适配)
多模态兼容性依赖额外输入拼接自然支持序列输入(文本/图像混合)

2. 适用场景建议

  • 高分辨率图像生成:Transformer的全局建模避免卷积局部信息丢失,细节更锐利(如1024×1024+)。
  • 复杂场景合成:序列相关性约束确保物体间空间关系合理,减少语义冲突(如多物体场景)。
  • 多模态基础模型:Token化输入支持多模态统一编码,便于扩展为跨模态生成模型(如DALL-E类)。

结论

DIT通过Transformer重构扩散模型,实现生成质量与效率的双重突破:

  1. 核心创新:时空联合编码(Patch序列+时间嵌入)、动态注意力机制(温度系数+噪声掩码)、非马尔可夫生成路径(支持跳步采样)。
  2. 技术优势:FID指标超越传统模型44.7%,推理速度提升5-10倍,参数量可扩展至千亿级适配多模态场景。
  3. 工程建议:根据计算资源选择Patch尺寸与模型规模,优化自注意力内存效率(如FlashAttention),优先应用于高分辨率、复杂场景生成任务。

关键公式总结

  • 条件概率重构: p θ ( x t − 1 ∣ x t ) = N ( DiT ( x t , t ) , σ t 2 I ) p_\theta(x_{t-1}|x_t) = \mathcal{N}(\text{DiT}(x_t, t), \sigma_t^2 I) pθ(xt1xt)=N(DiT(xt,t),σt2I)
  • 扩散感知注意力: Attn ( Q , K , V , t ) = Softmax ( Q K T ⊙ M i , j d ⋅ β t ) V \text{Attn}(Q,K,V,t) = \text{Softmax}\left( \frac{QK^T \odot M_{i,j}}{\sqrt{d} \cdot \sqrt{\beta_t}} \right)V Attn(Q,K,V,t)=Softmax(d βt QKTMi,j)V

代码实现参考
Meta官方仓库:https://github.com/facebookresearch/DiT
重点优化:Patch分割(nn.Unfold)、时间嵌入(频域编码)、动态注意力(掩码矩阵与温度系数)。

http://www.xdnf.cn/news/3211.html

相关文章:

  • Java继承中super的使用方法
  • SI5338-EVB Usage Guide(LVPECL、LVDS、HCSL、CMOS、SSTL、HSTL)
  • 电子病历高质量语料库构建方法与架构项目(智能数据目录篇)
  • SD - WAN 跨境网络专线部署方式介绍
  • 大数据在远程医疗中的创新应用:如何重塑医疗行业的未来
  • python + segno 生成个人二维码
  • 全球气象站点年平均降水数据(1929-2024)
  • 大连理工大学选修课——机器学习笔记(4):NBM的原理及应用
  • 大连理工大学选修课——机器学习笔记(9):线性判别式与逻辑回归
  • 使用 ossutil 上传文件到阿里云 OSS
  • 基于连接感知的实时困倦分类图神经网络
  • 【数学】角谷猜想
  • 服务器热备份,服务器热备份的方法有哪些?
  • 猿人学web端爬虫攻防大赛赛题第13题——入门级cookie
  • 完美解决react-native文件直传阿里云oss问题一
  • Android学习总结之自定义view设计模式理解
  • Redis热key大key详解
  • ESP32开发-通过ENC28J60模块实现以太网设备
  • 从实列中学习linux shell6: 写一个 shell 脚本 过滤 恶意ip 攻击
  • css 数字从0开始增加的动画效果
  • 【数学建模国奖速成系列】优秀论文绘图复现代码(二)
  • DeepSeek V1:初代模型的架构与性能
  • 艺术与科技的双向奔赴——高一鑫荣获加州联合表彰
  • Java ResourceBundle 资源绑定详解
  • 腾讯元宝桌面客户端:基于Tauri的开源技术解析
  • Python GIL 与 pybind11 GIL管理机制
  • 模拟flexible.js 前端开发中的大屏布局方案
  • Hadoop虚拟机中配置hosts
  • 评价类模型数据预处理(定量指标值的无量纲化处理)
  • 从零构建 MCP Server 与 Client:打造你的第一个 AI 工具集成应用