当前位置: 首页 > ai >正文

基于Stable Diffusion XL模型进行文本生成图像的训练

基于Stable Diffusion XL模型进行文本生成图像的训练

flyfish

环境变量部分

export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/naruto-blip-captions"
  • MODEL_NAME:指定预训练模型的名称或路径。这里使用的是 stabilityai/stable-diffusion-xl-base-1.0,也就是Stable Diffusion XL的基础版本1.0。
  • VAE_NAME:指定变分自编码器(VAE)的名称或路径。madebyollin/sdxl-vae-fp16-fix 是针对Stable Diffusion XL的一个经过修复的VAE模型,适用于半精度(FP16)计算。
  • DATASET_NAME:指定训练所使用的数据集名称或路径。这里使用的是 lambdalabs/naruto-blip-captions,是一个包含火影忍者相关图像及其描述的数据集。

accelerate launch 命令参数部分

accelerate launch train_text_to_image_sdxl.py \

这行代码使用 accelerate 工具来启动 train_text_to_image_sdxl.py 脚本,accelerate 可以帮助我们在多GPU、TPU等环境下进行分布式训练。

脚本参数部分

  • --pretrained_model_name_or_path=$MODEL_NAME:指定预训练模型的名称或路径,这里使用前面定义的 MODEL_NAME 环境变量。
  • --pretrained_vae_model_name_or_path=$VAE_NAME:指定预训练VAE模型的名称或路径,使用前面定义的 VAE_NAME 环境变量。
  • --dataset_name=$DATASET_NAME:指定训练数据集的名称或路径,使用前面定义的 DATASET_NAME 环境变量。
  • --enable_xformers_memory_efficient_attention:启用 xformers 库的内存高效注意力机制,能减少训练过程中的内存占用。
  • --resolution=512 --center_crop --random_flip
    • --resolution=512:将输入图像的分辨率统一调整为512x512像素。
    • --center_crop:对图像进行中心裁剪,使其达到指定的分辨率。
    • --random_flip:在训练过程中随机对图像进行水平翻转,以增加数据的多样性。
  • --proportion_empty_prompts=0.2:设置空提示(没有文本描述)的样本在训练数据中的比例为20%。
  • --train_batch_size=1:每个训练批次包含的样本数量为1。
  • --gradient_accumulation_steps=4 --gradient_checkpointing
    • --gradient_accumulation_steps=4:梯度累积步数为4,即每4个批次的梯度进行一次更新,这样可以在有限的内存下模拟更大的批次大小。
    • --gradient_checkpointing:启用梯度检查点机制,通过减少内存使用来支持更大的模型和批次大小。
  • --max_train_steps=10000:最大训练步数为10000步。
  • --use_8bit_adam:使用8位Adam优化器,能减少内存占用。
  • --learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0
    • --learning_rate=1e-06:学习率设置为1e-6。
    • --lr_scheduler="constant":学习率调度器设置为常数,即训练过程中学习率保持不变。
    • --lr_warmup_steps=0:学习率预热步数为0,即不进行学习率预热。
  • --mixed_precision="fp16":使用半精度(FP16)混合精度训练,能减少内存使用并加快训练速度。
  • --report_to="wandb":将训练过程中的指标报告到Weights & Biases(WandB)平台,方便进行可视化和监控。
  • --validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5
    • --validation_prompt="a cute Sundar Pichai creature":指定验证时使用的文本提示,这里是“一个可爱的桑达尔·皮查伊形象”。
    • --validation_epochs 5:每5个训练轮次进行一次验证。
  • --checkpointing_steps=5000:每5000步保存一次模型的检查点。
  • --output_dir="sdxl-naruto-model":指定训练好的模型的输出目录为 sdxl-naruto-model
  • --push_to_hub:将训练好的模型推送到Hugging Face模型库。

离线环境运行

# 假设已经把模型、VAE和数据集下载到本地了
# 这里假设模型在当前目录下的 sdxl-base-1.0 文件夹
# VAE 在 sdxl-vae-fp16-fix 文件夹
# 数据集在 naruto-blip-captions 文件夹# 定义本地路径
MODEL_NAME="./sdxl-base-1.0"
VAE_NAME="./sdxl-vae-fp16-fix"
DATASET_NAME="./naruto-blip-captions"# 移除需要外网连接的参数
accelerate launch train_text_to_image_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME \--enable_xformers_memory_efficient_attention \--resolution=512 --center_crop --random_flip \--proportion_empty_prompts=0.2 \--train_batch_size=1 \--gradient_accumulation_steps=4 --gradient_checkpointing \--max_train_steps=10000 \--use_8bit_adam \--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \--checkpointing_steps=5000 \--output_dir="sdxl-naruto-model"

移除需要外网连接的参数:去掉 --report_to="wandb"--push_to_hub 参数,因为 wandb 需要外网连接来上传训练指标,--push_to_hub 则需要外网连接把模型推送到Hugging Face模型库。

http://www.xdnf.cn/news/4631.html

相关文章:

  • 《社交应用架构生存战:React Native与Flutter的部署容灾决胜法则》
  • k8s(11) — 探针和钩子
  • SpringBoot学生操行评分系统源码设计开发
  • C++函数传值与传引用对比分析
  • 课外活动:简单了解原生测试框架Unittest前置后置的逻辑
  • 录播课视觉包装与转化率提升指南
  • 【NextPilot日志移植】整体功能概要
  • 迪士尼机器人BD-X 概况
  • 5G + AR:让增强现实真正“实时交互”起来
  • 前端取经路——框架修行:React与Vue的双修之路
  • 数据来源合法性尽职调查:保障权益的关键防线
  • Android不能下载Gradle,解决方法Could not install Gradle distribution from.......
  • 2025最新:3分钟使用Docker快速部署单节点Redis
  • python+open3d获取点云的最小外接球体及使用球体裁剪点云
  • 蓝桥杯青少 图形化编程(Scratch)每日一练——校门外的树
  • VGGNet详解
  • java集成telegram机器人
  • [特殊字符]【实战教程】用大模型LLM查询Neo4j图数据库(附完整代码)
  • 赋能金融科技创新,Telerik打造高效、安全的金融应用解决方案!
  • Linux58 ssh服务配置 jumpserver 测试双网卡 为何不能ping通ip地址
  • 从ellisys空口分析蓝牙耳机回连手机失败案例
  • 正则表达式(Regular Expression)详解
  • 关于ubuntu下交叉编译arrch64下的gtsam报错问题,boost中boost_regex.so中连接libicui18n.so.55报错的问题
  • 【Python 字符串】
  • Java常用API:深度解析与实践应用
  • 【Spring Boot 多模块项目】@MapperScan失效、MapperScannerConfigurer 报错终极解决方案
  • 安装 Docker
  • ZC706开发板教程:windows下编译ADRV9009
  • vue 中如何使用region?
  • PyTorch 实战:从 0 开始搭建 Transformer