SkyReels-V2 视频生成
SkyReels-V2 视频生成
flyfish
扩散强制(DF)模型:专为无限长度视频生成设计,提供1.3B-540P和14B-720P等版本
文本到视频(T2V)模型:专注于从文本提示生成高质量视频
图像到视频(I2V)模型:能够从输入图像生成连贯的视频序列
import os
# 设置TOKENIZERS_PARALLELISM为false,避免分词器并行化可能带来的问题
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import argparse
import gc
import os
import random
import time
import jsonimport imageio
import torch
from diffusers.utils import load_image
from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.pipelines import PromptEnhancer
from skyreels_v2_infer.pipelines import resizecrop# 单例模式元类,确保类只有一个实例
class Singleton(type):_instances = {}def __call__(cls, *args, **kwargs):if cls not in cls._instances:cls._instances[cls] = super().__call__(*args, **kwargs)return cls._instances[cls]# 配置解析类,用于解析命令行参数
class ConfigParser:def __init__(self):self.parser = argparse.ArgumentParser()self._add_arguments()def _add_arguments(self):# 输出目录self.parser.add_argument("--outdir", type=str, default="diffusion_forcing")# 模型IDself.parser.add_argument("--model_id",type=str,default="/media/models/Skywork/SkyReels-V2-DF-1___3B-540P/",)# 分辨率self.parser.add_argument("--resolution", type=str, default="540P", choices=["540P", "720P"])# 帧数self.parser.add_argument("--num_frames", type=int, default=97)# 图像路径self.parser.add_argument("--image", type=str, default=None)# AR步骤self.parser.add_argument("--ar_step", type=int, default=0)# 是否使用因果注意力self.parser.add_argument("--causal_attention", action="store_true")# 因果块大小self.parser.add_argument("--causal_block_size", type=int, default=1)# 基础帧数self.parser.add_argument("--base_num_frames", type=int, default=97)# 重叠历史self.parser.add_argument("--overlap_history", type=int, default=None)# 添加噪声条件self.parser.add_argument("--addnoise_condition", type=int, default=0)# 引导比例self.parser.add_argument("--guidance_scale", type=float, default=6.0)# 偏移量self.parser.add_argument("--shift", type=float, default=8.0)# 推理步骤self.parser.add_argument("--inference_steps", type=int, default=30) # 30# 是否使用USPself.parser.add_argument("--use_usp", action="store_true")# 是否卸载self.parser.add_argument("--offload", action="store_true")# 帧率self.parser.add_argument("--fps", type=int, default=24)# 随机种子self.parser.add_argument("--seed", type=int, default=None)# 提示文件self.parser.add_argument("--prompt", type=str, default="prompt.json")# 是否使用提示增强器self.parser.add_argument("--prompt_enhancer", action="store_true")# 是否使用TEA缓存self.parser.add_argument("--teacache", action="store_true")# TEA缓存阈值self.parser.add_argument("--teacache_thresh",type=float,default=0.2,help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup",)# 是否使用保留步骤self.parser.add_argument("--use_ret_steps",action="store_true",help="Using Retention Steps will result in faster generation speed and better generation quality.",)def parse(self):return self.parser.parse_args()# 环境设置类,用于设置运行环境
class EnvironmentSetup:def __init__(self, args):self.args = argsself._validate_seed()self._set_resolution()self._validate_num_frames()self._validate_addnoise_condition()# 负提示词self.negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"self._create_save_dir()self._setup_usp()def _validate_seed(self):# 验证种子是否有效,USP模式需要种子assert (self.args.use_usp and self.args.seed is not None) or (not self.args.use_usp), "usp mode need seed"if self.args.seed is None:random.seed(time.time())self.args.seed = int(random.randrange(4294967294))def _set_resolution(self):# 根据分辨率参数设置高度和宽度if self.args.resolution == "540P":self.height = 544self.width = 960elif self.args.resolution == "720P":self.height = 720self.width = 1280else:raise ValueError(f"Invalid resolution: {self.args.resolution}")def _validate_num_frames(self):# 验证帧数是否有效,长视频生成需要指定重叠历史if self.args.num_frames > self.args.base_num_frames:assert (self.args.overlap_history is not None), 'You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommanded to set.'def _validate_addnoise_condition(self):# 验证添加噪声条件是否有效,值过大可能导致长视频生成不一致if self.args.addnoise_condition > 60:print(f'You have set "addnoise_condition" as {self.args.addnoise_condition}. The value is too large which can cause inconsistency in long video generation. The value is recommanded to set 20.')def _create_save_dir(self):# 创建保存目录self.save_dir = os.path.join("result", self.args.outdir)os.makedirs(self.save_dir, exist_ok=True)def _setup_usp(self):self.local_rank = 0if self.args.use_usp:# USP模式下不允许使用提示增强器assert (not self.args.prompt_enhancer), "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter."from xfuser.core.distributed import (initialize_model_parallel,init_distributed_environment,)import torch.distributed as dist# 初始化分布式环境dist.init_process_group("nccl")self.local_rank = dist.get_rank()torch.cuda.set_device(dist.get_rank())self.device = "cuda"init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())initialize_model_parallel(sequence_parallel_degree=dist.get_world_size(),ring_degree=1,ulysses_degree=dist.get_world_size(),)# 管道设置类,用于创建和配置DiffusionForcingPipeline
class PipelineSetup(metaclass=Singleton):def __init__(self, args):self.pipe = DiffusionForcingPipeline(args.model_id,dit_path=args.model_id,device=torch.device("cuda"),weight_dtype=torch.bfloat16,use_usp=args.use_usp,offload=args.offload,)if args.causal_attention:# 设置因果注意力self.pipe.transformer.set_ar_attention(args.causal_block_size)if args.teacache:if args.ar_step > 0:# 计算推理步骤数num_steps = (args.inference_steps+ (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1)* args.ar_step)print("num_steps:", num_steps)else:num_steps = args.inference_steps# 初始化TEA缓存self.pipe.transformer.initialize_teacache(enable_teacache=True,num_steps=num_steps,teacache_thresh=args.teacache_thresh,use_ret_steps=args.use_ret_steps,ckpt_dir=args.model_id,)# 提示加载类,用于加载提示信息
class PromptLoader:def __init__(self, args):self.args = argsdef load(self):# 加载提示文件,如果文件存在且为JSON格式,则解析JSON文件,否则返回默认提示if os.path.exists(self.args.prompt) and self.args.prompt.endswith(".json"):with open(self.args.prompt, "r", encoding="utf-8") as f:return json.load(f)return [{"prompt": self.args.prompt}]# 提示增强包装类,用于增强提示信息
class PromptEnhancerWrapper:def __init__(self, args):self.args = argsdef enhance(self, prompt_input, image):if self.args.prompt_enhancer and image is None:print(f"init prompt enhancer")prompt_enhancer = PromptEnhancer()# 增强提示信息prompt_input = prompt_enhancer(prompt_input)print(f"enhanced prompt: {prompt_input}")del prompt_enhancergc.collect()torch.cuda.empty_cache()return prompt_input# 视频生成类,用于生成视频帧
class VideoGenerator:def __init__(self, pipe):self.pipe = pipedef generate(self, prompt_input, negative_prompt, image, height, width, num_frames,num_inference_steps, shift, guidance_scale, seed, overlap_history,addnoise_condition, base_num_frames, ar_step, causal_block_size, fps):with torch.cuda.amp.autocast(dtype=self.pipe.transformer.dtype), torch.no_grad():# 生成视频帧return self.pipe(prompt=prompt_input,negative_prompt=negative_prompt,image=image,height=height,width=width,num_frames=num_frames,num_inference_steps=num_inference_steps,shift=shift,guidance_scale=guidance_scale,generator=torch.Generator(device="cuda").manual_seed(seed),overlap_history=overlap_history,addnoise_condition=addnoise_condition,base_num_frames=base_num_frames,ar_step=ar_step,causal_block_size=causal_block_size,fps=fps,)[0]# 视频保存类,用于保存生成的视频
class VideoSaver:def save(self, video_frames, save_dir, prompt_input, seed, fps):# 生成视频文件名current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())video_out_file = f"{prompt_input[:100].replace('/','')}_{seed}_{current_time}.mp4"output_path = os.path.join(save_dir, video_out_file)# 保存视频imageio.mimwrite(output_path,video_frames,fps=fps,quality=8,output_params=["-loglevel", "error"],)# 视频生成应用类,作为应用程序的入口点,协调各个类的工作class VideoGenerationApp:def __init__(self):self.config_parser = ConfigParser()self.args = self.config_parser.parse()self.env_setup = EnvironmentSetup(self.args)self.pipeline_setup = PipelineSetup(self.args)self.prompt_loader = PromptLoader(self.args)self.prompt_enhancer = PromptEnhancerWrapper(self.args)self.video_generator = VideoGenerator(self.pipeline_setup.pipe)self.video_saver = VideoSaver()# 新增统计相关属性self.video_count = 0self.current_video = 0self.time_records = []self.total_time = 0.0def run(self):prompts = self.prompt_loader.load()self.video_count = len(prompts)self.current_video = 0self.time_records.clear()self.total_time = 0.0for prompt_info in prompts:self.current_video += 1start_time = time.perf_counter() # 记录开始时间prompt_input = prompt_info["prompt"]image = Noneif "image_paths" in prompt_info:image_path = prompt_info["image_paths"][0]image = load_image(image_path)image_width, image_height = image.sizeif image_height > image_width:self.env_setup.height, self.env_setup.width = self.env_setup.width, self.env_setup.heightimage = resizecrop(image, self.env_setup.height, self.env_setup.width)image = image.convert("RGB")prompt_input = self.prompt_enhancer.enhance(prompt_input, image)print(f"\n=== Video {self.current_video}/{self.video_count} ===")print(f"Prompt: {prompt_input[:100]}...")print(f"Guidance Scale: {self.env_setup.args.guidance_scale}")video_frames = self.video_generator.generate(prompt_input,self.env_setup.negative_prompt,image,self.env_setup.height,self.env_setup.width,self.env_setup.args.num_frames,self.env_setup.args.inference_steps,self.env_setup.args.shift,self.env_setup.args.guidance_scale,self.env_setup.args.seed,self.env_setup.args.overlap_history,self.env_setup.args.addnoise_condition,self.env_setup.args.base_num_frames,self.env_setup.args.ar_step,self.env_setup.args.causal_block_size,self.env_setup.args.fps,)# 计算本次推理时间duration = time.perf_counter() - start_timeself.time_records.append(duration)self.total_time += durationprint(f"Generation completed in {duration:.2f} seconds")if self.env_setup.local_rank == 0:self.video_saver.save(video_frames, self.env_setup.save_dir, prompt_input, self.env_setup.args.seed,self.env_setup.args.fps)print(f"Video saved to {self.env_setup.save_dir}")# 新增统计结果汇总if self.env_setup.local_rank == 0 and self.video_count > 0:avg_time = self.total_time / self.video_countmax_time = max(self.time_records) if self.time_records else 0print("\n=== Generation Statistics ===")print(f"Total Videos: {self.video_count}")print(f"Total Time: {self.total_time:.2f} seconds")print(f"Average Time per Video: {avg_time:.2f} seconds")print(f"Max Time per Video: {max_time:.2f} seconds")if __name__ == "__main__":app = VideoGenerationApp()app.run()
使用说明
1. 环境准备
- 模型路径:将
--model_id
参数指向正确的模型目录(默认路径为示例路径,需根据实际情况修改)。 - 提示文件:准备提示文件(默认
prompt.json
),格式为JSON,支持多提示输入:[{"prompt": "your first prompt", "image_paths": ["image1.jpg"]},{"prompt": "your second prompt", "image_paths": ["image2.jpg"]} ]
2. 关键参数说明
参数名 | 功能描述 |
---|---|
--outdir | 输出目录,视频将保存在result/{outdir} 下。 |
--resolution | 视频分辨率,支持540P (544x960)和720P (720x1280)。 |
--num_frames | 生成视频的总帧数(长视频需配合--overlap_history 参数)。 |
--prompt | 提示文件路径(JSON格式)或直接输入提示词(非JSON时默认使用单提示)。 |
--seed | 随机种子(固定种子可复现结果,--use_usp 模式下必须设置)。 |
--guidance_scale | 生成质量控制参数(值越大越贴近提示,建议6.0-8.0)。 |
--image | 初始图像路径(可选,用于图像生成视频)。 |
--use_usp | 启用分布式模式(需多GPU支持,需提前初始化分布式环境)。 |
3. 运行命令
python script_name.py [参数列表]
- 示例:生成一个720P、97帧、使用默认提示的视频:
使用默认python video_generation_refactored.py --resolution 720P --num_frames 97
python video_generation_refactored.py --prompt prompt.json
思路
1. 模块化设计(单一职责原则)
将复杂功能拆解为独立类,每个类专注于单一职责:
ConfigParser
:解析命令行参数,统一管理输入配置。EnvironmentSetup
:验证参数合法性、设置运行环境(分辨率、保存目录、分布式配置等)。PipelineSetup
:初始化模型管道(DiffusionForcingPipeline
),配置推理参数(因果注意力、TEA缓存等)。PromptLoader/Enhancer
:加载提示文件并按需增强提示词(提升生成效果)。VideoGenerator/Saver
:分离视频生成和保存逻辑,解耦核心功能与IO操作。
2. 单例模式(资源优化)
PipelineSetup
使用单例模式:确保模型管道全局唯一,避免重复加载模型浪费内存,提升效率。- 适用场景:模型体积大、初始化耗时,单例模式保证内存中仅存在一个实例。
3. 分布式支持(扩展性)
--use_usp
参数:支持多GPU分布式推理,通过xfuser
库初始化分布式环境,提升大规模生成效率。- 约束机制:分布式模式下禁止使用提示增强器(
--prompt_enhancer
),确保逻辑一致性。
4. 鲁棒性设计(参数验证与异常处理)
- 强参数校验:
- USP模式强制要求种子(
--seed
),避免随机初始化导致的分布式不一致。 - 长视频生成(
--num_frames > base_num_frames
)强制要求--overlap_history
,确保时序连贯性。 - 分辨率严格限制为
540P
/720P
,避免无效输入。
- USP模式强制要求种子(
- 内存管理:提示增强器使用后手动释放资源(
gc.collect()
),避免内存泄漏。
5. 流程解耦与协作
VideoGenerationApp
作为协调者:串联各模块,按“加载配置→初始化环境→生成视频→保存结果”的流程执行。- 依赖注入:通过类构造函数传递依赖(如管道对象、配置参数),降低模块间耦合度。
6. 输出与可追溯性
- 自动生成文件名:包含提示词(前100字)、种子、时间戳,便于区分不同生成任务。
- 保存目录结构:统一输出到
result/{outdir}
,支持断点续传(exist_ok=True
)。