"""
Wan22 - OpenAI-compatible video generation runtime.

Features:
1. OpenAI-style response payloads.
2. A strict process-wide serial inference lock (GPU safety).
3. Supports t2v / i2v / ti2v / s2v tasks.
4. The model stays resident in memory and is reused across requests.
"""
import os
import random
import tempfile
import threading
import time
import uuid
from dataclasses import dataclass

import torch
from PIL import Image

import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS
from wan.utils.prompt_extend import QwenPromptExpander

# Process-wide execution lock (critical): every Wan22.generate() call, on any
# instance, runs inside this lock, so inference is strictly serialized.
_GLOBAL_INFER_LOCK = threading.Lock()

# Upper bound (inclusive) for randomly drawn seeds.
_MAX_SEED = 2**31 - 1


@dataclass
class OpenAIVideoResponse:
    id: str
    object: str
    created: int
    prompt: str
    video_path: str
    seed: int


class Wan22:
    """
    Wan22 - OpenAI-compatible video generation runtime.

    Builds one Wan pipeline for ``task`` at construction time and keeps it as
    ``self.pipeline`` so repeated :meth:`generate` calls reuse the loaded model.
    """

    def __init__(
        self,
        ckpt_dir: str,
        task: str = "ti2v-5B",
        device_id: int = 0,
        use_prompt_extend: bool = False,
        prompt_extend_model: str = None,
        seed: int = -1,
        offload_model: bool = True,
    ):
        """
        Args:
            ckpt_dir: Checkpoint directory passed to the Wan pipeline.
            task: Key into ``WAN_CONFIGS``; its substring ("t2v", "ti2v",
                "i2v", "s2v") selects the pipeline class.
            device_id: Device index for the pipeline and the prompt expander.
            use_prompt_extend: If True, prompts are rewritten by a
                ``QwenPromptExpander`` before generation.
            prompt_extend_model: Model name for the prompt expander.
            seed: Default seed for every request; a negative value draws a
                random seed once, at construction.
            offload_model: Forwarded as ``offload_model`` to every
                ``pipeline.generate`` call.
        """
        assert ckpt_dir, "ckpt_dir required"
        assert task in WAN_CONFIGS, f"unknown task: {task}"

        self.ckpt_dir = ckpt_dir
        self.task = task
        self.device_id = device_id
        self.cfg = WAN_CONFIGS[task]
        self.seed = seed if seed >= 0 else random.randint(0, _MAX_SEED)
        self.offload_model = offload_model
        self.use_prompt_extend = use_prompt_extend

        # NOTE(review): is_vl=True is used for every task, including t2v where
        # no image is ever passed; confirm QwenPromptExpander accepts
        # image=None in VL mode for the installed wan version.
        self.prompt_expander = (
            QwenPromptExpander(
                model_name=prompt_extend_model,
                task=task,
                is_vl=True,
                device=device_id,
            )
            if use_prompt_extend
            else None
        )

        self.pipeline = self._build_pipeline()

    def _build_pipeline(self):
        """Instantiate the Wan pipeline class matching ``self.task``.

        "ti2v" must be tested before "i2v": "i2v" is a substring of "ti2v",
        so testing "i2v" first would build WanI2V for ti2v tasks.

        Raises:
            ValueError: if the task name matches none of the known kinds.
        """
        if "t2v" in self.task:
            pipeline_cls = wan.WanT2V
        elif "ti2v" in self.task:
            pipeline_cls = wan.WanTI2V
        elif "i2v" in self.task:
            pipeline_cls = wan.WanI2V
        elif "s2v" in self.task:
            pipeline_cls = wan.WanS2V
        else:
            raise ValueError(self.task)

        return pipeline_cls(
            config=self.cfg,
            checkpoint_dir=self.ckpt_dir,
            device_id=self.device_id,
            rank=0,
        )

    def _expand(self, prompt, image=None):
        """Optionally rewrite ``prompt`` with the prompt expander.

        Returns the prompt unchanged when prompt extension is disabled, or
        when the expander's result has a falsy ``status``; otherwise returns
        the expander's ``prompt``.
        """
        if not self.use_prompt_extend:
            return prompt
        out = self.prompt_expander(prompt, image=image)
        return out.prompt if out.status else prompt

    @staticmethod
    def _load_image(image_path):
        """Return ``image_path`` opened as an RGB PIL image, or None if no path.

        The file is opened in a ``with`` block so its handle is closed;
        ``convert`` returns a new, fully loaded image that outlives it.
        """
        if not image_path:
            return None
        with Image.open(image_path) as img:
            return img.convert("RGB")

    def _sampling_arg(self, kwargs, key, cfg_attr):
        """Return ``kwargs[key]`` if given and not None, else ``self.cfg.<cfg_attr>``.

        Passing None explicitly to the pipeline would override its own
        defaults, so omitted values fall back to the task config. If the
        config lacks ``cfg_attr``, None is returned.
        """
        value = kwargs.get(key)
        if value is None:
            value = getattr(self.cfg, cfg_attr, None)
        return value

    def _pack(self, prompt, video_path, seed):
        """Build the OpenAI-style response dict for one generated video."""
        return {
            "id": f"wan_{uuid.uuid4().hex}",
            "object": "video.generation",
            "created": int(time.time()),
            "prompt": prompt,
            "video_path": video_path,
            "seed": seed,
        }

    def generate(self, **kwargs):
        """
        OpenAI-style unified entry point.

        The whole call runs under the process-wide serial lock, so only one
        generation (inference plus save) executes at a time.

        Recognized kwargs:
            prompt, image_path, size (default "1280*720"), seed (defaults to
            the instance seed; negative draws a random one), frame_num, shift,
            solver (default "unipc"), steps, guide_scale, save_file.
            s2v only: audio_path, enable_tts, tts_prompt_audio,
            tts_prompt_text, tts_text, num_clip, pose_video, infer_frames
            (default 80), start_from_ref.
            Omitted frame_num / shift / steps / guide_scale fall back to the
            task config (frame_num / sample_shift / sample_steps /
            sample_guide_scale).

        Returns:
            dict with keys id, object, created, prompt (possibly expanded),
            video_path, seed.

        Raises:
            KeyError: if ``size`` is not in SIZE_CONFIGS.
            ValueError: for an i2v task without ``image_path``, or an
                unknown task.
        """
        with _GLOBAL_INFER_LOCK:
            image_path = kwargs.get("image_path")
            # Open the image once and reuse it for prompt expansion and inference.
            image = self._load_image(image_path)
            prompt = self._expand(kwargs.get("prompt"), image=image)

            size = kwargs.get("size", "1280*720")
            size_cfg = SIZE_CONFIGS[size]

            seed = kwargs.get("seed")
            if seed is None:
                seed = self.seed
            elif seed < 0:
                # Resolve here so the response reports the seed actually used.
                seed = random.randint(0, _MAX_SEED)

            sampling = dict(
                shift=self._sampling_arg(kwargs, "shift", "sample_shift"),
                sample_solver=kwargs.get("solver", "unipc"),
                sampling_steps=self._sampling_arg(kwargs, "steps", "sample_steps"),
                guide_scale=self._sampling_arg(
                    kwargs, "guide_scale", "sample_guide_scale"
                ),
                seed=seed,
                offload_model=self.offload_model,
            )

            # "ti2v" is tested before "i2v" because "i2v" is a substring of it.
            if "t2v" in self.task:
                video = self.pipeline.generate(
                    prompt,
                    size=size_cfg,
                    frame_num=self._sampling_arg(kwargs, "frame_num", "frame_num"),
                    **sampling,
                )
            elif "ti2v" in self.task:
                # image may be None here; assumes WanTI2V.generate treats
                # img=None as text-only generation — confirm against the
                # installed wan version.
                video = self.pipeline.generate(
                    prompt,
                    img=image,
                    size=size_cfg,
                    max_area=MAX_AREA_CONFIGS[size],
                    frame_num=self._sampling_arg(kwargs, "frame_num", "frame_num"),
                    **sampling,
                )
            elif "i2v" in self.task:
                if image is None:
                    raise ValueError(f"task {self.task} requires image_path")
                video = self.pipeline.generate(
                    prompt,
                    image,
                    size=size_cfg,
                    max_area=MAX_AREA_CONFIGS[size],
                    frame_num=self._sampling_arg(kwargs, "frame_num", "frame_num"),
                    **sampling,
                )
            elif "s2v" in self.task:
                video = self.pipeline.generate(
                    input_prompt=prompt,
                    ref_image_path=image_path,
                    audio_path=kwargs.get("audio_path"),
                    enable_tts=kwargs.get("enable_tts", False),
                    tts_prompt_audio=kwargs.get("tts_prompt_audio"),
                    tts_prompt_text=kwargs.get("tts_prompt_text"),
                    tts_text=kwargs.get("tts_text"),
                    num_repeat=kwargs.get("num_clip"),
                    pose_video=kwargs.get("pose_video"),
                    max_area=MAX_AREA_CONFIGS[size],
                    infer_frames=kwargs.get("infer_frames", 80),
                    init_first_frame=kwargs.get("start_from_ref", False),
                    **sampling,
                )
            else:
                raise ValueError(self.task)

            # Save the video; an omitted or None save_file gets a unique
            # temp-dir path.
            video_path = kwargs.get("save_file")
            if video_path is None:
                video_path = os.path.join(
                    tempfile.gettempdir(), f"{uuid.uuid4().hex}.mp4"
                )

            from wan.utils.utils import save_video

            save_video(
                tensor=video[None],
                save_file=video_path,
                fps=self.cfg.sample_fps,
                nrow=1,
                normalize=True,
                value_range=(-1, 1),
            )

            return self._pack(prompt, video_path, seed)