当前位置: 首页 > news >正文

SkyReels-V2 视频生成

SkyReels-V2 视频生成

flyfish

扩散强制(DF)模型:专为无限长度视频生成设计,提供1.3B-540P和14B-720P等版本
文本到视频(T2V)模型:专注于从文本提示生成高质量视频
图像到视频(I2V)模型:能够从输入图像生成连贯的视频序列

import os
# 设置TOKENIZERS_PARALLELISM为false,避免分词器并行化可能带来的问题
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import argparse
import gc
import os
import random
import time
import jsonimport imageio
import torch
from diffusers.utils import load_image
from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.pipelines import PromptEnhancer
from skyreels_v2_infer.pipelines import resizecrop# 单例模式元类,确保类只有一个实例
class Singleton(type):_instances = {}def __call__(cls, *args, **kwargs):if cls not in cls._instances:cls._instances[cls] = super().__call__(*args, **kwargs)return cls._instances[cls]# 配置解析类,用于解析命令行参数
class ConfigParser:def __init__(self):self.parser = argparse.ArgumentParser()self._add_arguments()def _add_arguments(self):# 输出目录self.parser.add_argument("--outdir", type=str, default="diffusion_forcing")# 模型IDself.parser.add_argument("--model_id",type=str,default="/media/models/Skywork/SkyReels-V2-DF-1___3B-540P/",)# 分辨率self.parser.add_argument("--resolution", type=str, default="540P", choices=["540P", "720P"])# 帧数self.parser.add_argument("--num_frames", type=int, default=97)# 图像路径self.parser.add_argument("--image", type=str, default=None)# AR步骤self.parser.add_argument("--ar_step", type=int, default=0)# 是否使用因果注意力self.parser.add_argument("--causal_attention", action="store_true")# 因果块大小self.parser.add_argument("--causal_block_size", type=int, default=1)# 基础帧数self.parser.add_argument("--base_num_frames", type=int, default=97)# 重叠历史self.parser.add_argument("--overlap_history", type=int, default=None)# 添加噪声条件self.parser.add_argument("--addnoise_condition", type=int, default=0)# 引导比例self.parser.add_argument("--guidance_scale", type=float, default=6.0)# 偏移量self.parser.add_argument("--shift", type=float, default=8.0)# 推理步骤self.parser.add_argument("--inference_steps", type=int, default=30)  # 30# 是否使用USPself.parser.add_argument("--use_usp", action="store_true")# 是否卸载self.parser.add_argument("--offload", action="store_true")# 帧率self.parser.add_argument("--fps", type=int, default=24)# 随机种子self.parser.add_argument("--seed", type=int, default=None)# 提示文件self.parser.add_argument("--prompt", type=str, default="prompt.json")# 是否使用提示增强器self.parser.add_argument("--prompt_enhancer", action="store_true")# 是否使用TEA缓存self.parser.add_argument("--teacache", action="store_true")# TEA缓存阈值self.parser.add_argument("--teacache_thresh",type=float,default=0.2,help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup",)# 是否使用保留步骤self.parser.add_argument("--use_ret_steps",action="store_true",help="Using Retention Steps will result in faster generation speed and better generation quality.",)def parse(self):return self.parser.parse_args()# 环境设置类,用于设置运行环境
class EnvironmentSetup:def __init__(self, args):self.args = argsself._validate_seed()self._set_resolution()self._validate_num_frames()self._validate_addnoise_condition()# 负提示词self.negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"self._create_save_dir()self._setup_usp()def _validate_seed(self):# 验证种子是否有效,USP模式需要种子assert (self.args.use_usp and self.args.seed is not None) or (not self.args.use_usp), "usp mode need seed"if self.args.seed is None:random.seed(time.time())self.args.seed = int(random.randrange(4294967294))def _set_resolution(self):# 根据分辨率参数设置高度和宽度if self.args.resolution == "540P":self.height = 544self.width = 960elif self.args.resolution == "720P":self.height = 720self.width = 1280else:raise ValueError(f"Invalid resolution: {self.args.resolution}")def _validate_num_frames(self):# 验证帧数是否有效,长视频生成需要指定重叠历史if self.args.num_frames > self.args.base_num_frames:assert (self.args.overlap_history is not None), 'You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommanded to set.'def _validate_addnoise_condition(self):# 验证添加噪声条件是否有效,值过大可能导致长视频生成不一致if self.args.addnoise_condition > 60:print(f'You have set "addnoise_condition" as {self.args.addnoise_condition}. The value is too large which can cause inconsistency in long video generation. The value is recommanded to set 20.')def _create_save_dir(self):# 创建保存目录self.save_dir = os.path.join("result", self.args.outdir)os.makedirs(self.save_dir, exist_ok=True)def _setup_usp(self):self.local_rank = 0if self.args.use_usp:# USP模式下不允许使用提示增强器assert (not self.args.prompt_enhancer), "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter."from xfuser.core.distributed import (initialize_model_parallel,init_distributed_environment,)import torch.distributed as dist# 初始化分布式环境dist.init_process_group("nccl")self.local_rank = dist.get_rank()torch.cuda.set_device(dist.get_rank())self.device = "cuda"init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())initialize_model_parallel(sequence_parallel_degree=dist.get_world_size(),ring_degree=1,ulysses_degree=dist.get_world_size(),)# 管道设置类,用于创建和配置DiffusionForcingPipeline
class PipelineSetup(metaclass=Singleton):def __init__(self, args):self.pipe = DiffusionForcingPipeline(args.model_id,dit_path=args.model_id,device=torch.device("cuda"),weight_dtype=torch.bfloat16,use_usp=args.use_usp,offload=args.offload,)if args.causal_attention:# 设置因果注意力self.pipe.transformer.set_ar_attention(args.causal_block_size)if args.teacache:if args.ar_step > 0:# 计算推理步骤数num_steps = (args.inference_steps+ (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1)* args.ar_step)print("num_steps:", num_steps)else:num_steps = args.inference_steps# 初始化TEA缓存self.pipe.transformer.initialize_teacache(enable_teacache=True,num_steps=num_steps,teacache_thresh=args.teacache_thresh,use_ret_steps=args.use_ret_steps,ckpt_dir=args.model_id,)# 提示加载类,用于加载提示信息
class PromptLoader:def __init__(self, args):self.args = argsdef load(self):# 加载提示文件,如果文件存在且为JSON格式,则解析JSON文件,否则返回默认提示if os.path.exists(self.args.prompt) and self.args.prompt.endswith(".json"):with open(self.args.prompt, "r", encoding="utf-8") as f:return json.load(f)return [{"prompt": self.args.prompt}]# 提示增强包装类,用于增强提示信息
class PromptEnhancerWrapper:def __init__(self, args):self.args = argsdef enhance(self, prompt_input, image):if self.args.prompt_enhancer and image is None:print(f"init prompt enhancer")prompt_enhancer = PromptEnhancer()# 增强提示信息prompt_input = prompt_enhancer(prompt_input)print(f"enhanced prompt: {prompt_input}")del prompt_enhancergc.collect()torch.cuda.empty_cache()return prompt_input# 视频生成类,用于生成视频帧
class VideoGenerator:def __init__(self, pipe):self.pipe = pipedef generate(self, prompt_input, negative_prompt, image, height, width, num_frames,num_inference_steps, shift, guidance_scale, seed, overlap_history,addnoise_condition, base_num_frames, ar_step, causal_block_size, fps):with torch.cuda.amp.autocast(dtype=self.pipe.transformer.dtype), torch.no_grad():# 生成视频帧return self.pipe(prompt=prompt_input,negative_prompt=negative_prompt,image=image,height=height,width=width,num_frames=num_frames,num_inference_steps=num_inference_steps,shift=shift,guidance_scale=guidance_scale,generator=torch.Generator(device="cuda").manual_seed(seed),overlap_history=overlap_history,addnoise_condition=addnoise_condition,base_num_frames=base_num_frames,ar_step=ar_step,causal_block_size=causal_block_size,fps=fps,)[0]# 视频保存类,用于保存生成的视频
class VideoSaver:def save(self, video_frames, save_dir, prompt_input, seed, fps):# 生成视频文件名current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())video_out_file = f"{prompt_input[:100].replace('/','')}_{seed}_{current_time}.mp4"output_path = os.path.join(save_dir, video_out_file)# 保存视频imageio.mimwrite(output_path,video_frames,fps=fps,quality=8,output_params=["-loglevel", "error"],)# 视频生成应用类,作为应用程序的入口点,协调各个类的工作class VideoGenerationApp:def __init__(self):self.config_parser = ConfigParser()self.args = self.config_parser.parse()self.env_setup = EnvironmentSetup(self.args)self.pipeline_setup = PipelineSetup(self.args)self.prompt_loader = PromptLoader(self.args)self.prompt_enhancer = PromptEnhancerWrapper(self.args)self.video_generator = VideoGenerator(self.pipeline_setup.pipe)self.video_saver = VideoSaver()# 新增统计相关属性self.video_count = 0self.current_video = 0self.time_records = []self.total_time = 0.0def run(self):prompts = self.prompt_loader.load()self.video_count = len(prompts)self.current_video = 0self.time_records.clear()self.total_time = 0.0for prompt_info in prompts:self.current_video += 1start_time = time.perf_counter()  # 记录开始时间prompt_input = prompt_info["prompt"]image = Noneif "image_paths" in prompt_info:image_path = prompt_info["image_paths"][0]image = load_image(image_path)image_width, image_height = image.sizeif image_height > image_width:self.env_setup.height, self.env_setup.width = self.env_setup.width, self.env_setup.heightimage = resizecrop(image, self.env_setup.height, self.env_setup.width)image = image.convert("RGB")prompt_input = self.prompt_enhancer.enhance(prompt_input, image)print(f"\n=== Video {self.current_video}/{self.video_count} ===")print(f"Prompt: {prompt_input[:100]}...")print(f"Guidance Scale: {self.env_setup.args.guidance_scale}")video_frames = self.video_generator.generate(prompt_input,self.env_setup.negative_prompt,image,self.env_setup.height,self.env_setup.width,self.env_setup.args.num_frames,self.env_setup.args.inference_steps,self.env_setup.args.shift,self.env_setup.args.guidance_scale,self.env_setup.args.seed,self.env_setup.args.overlap_history,self.env_setup.args.addnoise_condition,self.env_setup.args.base_num_frames,self.env_setup.args.ar_step,self.env_setup.args.causal_block_size,self.env_setup.args.fps,)# 计算本次推理时间duration = time.perf_counter() - start_timeself.time_records.append(duration)self.total_time += durationprint(f"Generation completed in {duration:.2f} seconds")if self.env_setup.local_rank == 0:self.video_saver.save(video_frames, self.env_setup.save_dir, prompt_input, self.env_setup.args.seed,self.env_setup.args.fps)print(f"Video saved to {self.env_setup.save_dir}")# 新增统计结果汇总if self.env_setup.local_rank == 0 and self.video_count > 0:avg_time = self.total_time / self.video_countmax_time = max(self.time_records) if self.time_records else 0print("\n=== Generation Statistics ===")print(f"Total Videos: {self.video_count}")print(f"Total Time: {self.total_time:.2f} seconds")print(f"Average Time per Video: {avg_time:.2f} seconds")print(f"Max Time per Video: {max_time:.2f} seconds")if __name__ == "__main__":app = VideoGenerationApp()app.run()

使用说明

1. 环境准备
  • 模型路径:将--model_id参数指向正确的模型目录(默认路径为示例路径,需根据实际情况修改)。
  • 提示文件:准备提示文件(默认prompt.json),格式为JSON,支持多提示输入:
    [{"prompt": "your first prompt", "image_paths": ["image1.jpg"]},{"prompt": "your second prompt", "image_paths": ["image2.jpg"]}
    ]
    
2. 关键参数说明
参数名功能描述
--outdir输出目录,视频将保存在result/{outdir}下。
--resolution视频分辨率,支持540P(544x960)和720P(720x1280)。
--num_frames生成视频的总帧数(长视频需配合--overlap_history参数)。
--prompt提示文件路径(JSON格式)或直接输入提示词(非JSON时默认使用单提示)。
--seed随机种子(固定种子可复现结果,--use_usp模式下必须设置)。
--guidance_scale生成质量控制参数(值越大越贴近提示,建议6.0-8.0)。
--image初始图像路径(可选,用于图像生成视频)。
--use_usp启用分布式模式(需多GPU支持,需提前初始化分布式环境)。
3. 运行命令
python script_name.py [参数列表]
  • 示例:生成一个720P、97帧、使用默认提示的视频:
    python video_generation_refactored.py --resolution 720P --num_frames 97
    
    使用默认
  python video_generation_refactored.py  --prompt  prompt.json

思路

1. 模块化设计(单一职责原则)

将复杂功能拆解为独立类,每个类专注于单一职责:

  • ConfigParser:解析命令行参数,统一管理输入配置。
  • EnvironmentSetup:验证参数合法性、设置运行环境(分辨率、保存目录、分布式配置等)。
  • PipelineSetup:初始化模型管道(DiffusionForcingPipeline),配置推理参数(因果注意力、TEA缓存等)。
  • PromptLoader/Enhancer:加载提示文件并按需增强提示词(提升生成效果)。
  • VideoGenerator/Saver:分离视频生成和保存逻辑,解耦核心功能与IO操作。
2. 单例模式(资源优化)
  • PipelineSetup使用单例模式:确保模型管道全局唯一,避免重复加载模型浪费内存,提升效率。
  • 适用场景:模型体积大、初始化耗时,单例模式保证内存中仅存在一个实例。
3. 分布式支持(扩展性)
  • --use_usp参数:支持多GPU分布式推理,通过xfuser库初始化分布式环境,提升大规模生成效率。
  • 约束机制:分布式模式下禁止使用提示增强器(--prompt_enhancer),确保逻辑一致性。
4. 鲁棒性设计(参数验证与异常处理)
  • 强参数校验
    • USP模式强制要求种子(--seed),避免随机初始化导致的分布式不一致。
    • 长视频生成(--num_frames > base_num_frames)强制要求--overlap_history,确保时序连贯性。
    • 分辨率严格限制为540P/720P,避免无效输入。
  • 内存管理:提示增强器使用后手动释放资源(gc.collect()),避免内存泄漏。
5. 流程解耦与协作
  • VideoGenerationApp作为协调者:串联各模块,按“加载配置→初始化环境→生成视频→保存结果”的流程执行。
  • 依赖注入:通过类构造函数传递依赖(如管道对象、配置参数),降低模块间耦合度。
6. 输出与可追溯性
  • 自动生成文件名:包含提示词(前100字)、种子、时间戳,便于区分不同生成任务。
  • 保存目录结构:统一输出到result/{outdir},支持断点续传(exist_ok=True)。
http://www.xdnf.cn/news/379081.html

相关文章:

  • Cadence 高速系统设计流程及工具使用三
  • 加速pip下载:永久解决网络慢问题
  • 数据集-目标检测系列- 冥想 检测数据集 close_eye>> DataBall
  • AI实战笔记(1)AI 的 6 大核心方向 + 学习阶段路径
  • Linxu实验五——NFS服务器
  • WordPress插件targetsms存在远程命令执行漏洞(CVE-2025-3776)
  • 20250510-查看 Anaconda 配置的镜像源
  • redis未授权访问
  • [架构之美]从零开始整合Spring Boot与Maven(十五)
  • AUTODL Chatglm2 langchain 部署大模型聊天助手
  • C语言初阶秘籍6
  • 二分法和牛顿迭代法解方程实根,详解
  • 第十九节:图像梯度与边缘检测- Laplacian 算子
  • 「OC」源码学习——cache_t的原理探究
  • C32-编程案例用函数封装获取两个数的较大数
  • IPFS与去中心化存储:重塑数字世界的基石
  • nuscenes_devkit工具
  • Windows:Powershell的使用
  • 进阶二:基于HC-SR04和LCD1602的超声波测距
  • 海纳思(Hi3798MV300)机顶盒遇到海思摄像头
  • 贪心算法专题(Part1)
  • AI大模型学习十七、利用Dify搭建 AI 图片生成应用
  • STL-to-ASCII-Generator 实用教程
  • SpringBoot2集成xxl-job详解
  • 大模型微调指南之 LLaMA-Factory 篇:一键启动LLaMA系列模型高效微调
  • 差动讯号(3)弱耦合与强耦合
  • Linux数据库篇、第一章_01MySQL5.7的安装部署
  • Java基础 5.10
  • 致远A8V5-9.0安装包(包含信创版)【附百度网盘链接】
  • LeetCode 热题 100 24. 两两交换链表中的节点