wan22-service/workers/wan22_wrapper.py

253 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
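Wan22 - OpenAI-compatible Video Generation Runtime

Features:
1. OpenAI-style responses
2. Strictly serialized inference behind a lock (GPU safety)
3. Supports t2v / i2v / ti2v / s2v
4. Model stays resident in memory and is reused across tasks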
"""
import os
import time
import uuid
import threading
import random
from dataclasses import dataclass
import torch
from PIL import Image
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS
from wan.utils.prompt_extend import QwenPromptExpander
# Global execution lock (critical): Wan22.generate() acquires it around
# inference so that generations in this process run one at a time.
_GLOBAL_INFER_LOCK = threading.Lock()
@dataclass
class OpenAIVideoResponse:
    """Typed record describing one generated video, OpenAI-style.

    The field names are the same as the keys of the dict built by
    ``Wan22._pack``. Field order is part of the constructor signature.
    """

    # Identifier of this response.
    id: str
    # Kind of object the response represents.
    object: str
    # Creation time as an integer timestamp.
    created: int
    # Prompt associated with the video.
    prompt: str
    # Location of the saved video file.
    video_path: str
    # Seed associated with the generation.
    seed: int
class Wan22:
    """Wan22 - OpenAI-compatible video generation runtime.

    Builds one Wan pipeline (selected from ``task``) at construction time,
    keeps it on the instance, and exposes :meth:`generate` as the single entry
    point. Inference is serialized through the module-level
    ``_GLOBAL_INFER_LOCK``.
    """

    # Substrings used to classify a task name, in match order.
    # "ti2v" must be tested before "i2v": every "ti2v" task name also contains
    # the substring "i2v" (it does NOT contain "t2v"). Keeping the order in one
    # place guarantees _build_pipeline() and generate() always agree.
    _TASK_KINDS = ("s2v", "ti2v", "t2v", "i2v")

    # Pipeline class name (looked up on the ``wan`` package) per task kind.
    _PIPELINE_CLASS_NAMES = {
        "s2v": "WanS2V",
        "ti2v": "WanTI2V",
        "t2v": "WanT2V",
        "i2v": "WanI2V",
    }

    def __init__(
        self,
        ckpt_dir: str,
        task: str = "ti2v-5B",
        device_id: int = 0,
        use_prompt_extend: bool = False,
        prompt_extend_model: str = None,
        seed: int = -1,
        offload_model: bool = True,
    ):
        """Load the pipeline (and optionally the prompt expander).

        Args:
            ckpt_dir: Checkpoint directory handed to the pipeline.
            task: Key into ``WAN_CONFIGS``; also selects the pipeline class.
            device_id: Device index passed to the pipeline and the expander.
            use_prompt_extend: Build a ``QwenPromptExpander`` and run every
                prompt through it.
            prompt_extend_model: Model name passed to the expander.
            seed: Default seed for requests; a negative value draws one
                random seed here, which is then reused as the default.
            offload_model: Forwarded to every ``pipeline.generate`` call.

        Raises:
            AssertionError: If ``ckpt_dir`` is empty or ``task`` is unknown.
        """
        assert ckpt_dir, "ckpt_dir required"
        assert task in WAN_CONFIGS, f"unsupported task: {task}"

        self.ckpt_dir = ckpt_dir
        self.task = task
        self.device_id = device_id
        self.cfg = WAN_CONFIGS[task]
        self.seed = seed if seed >= 0 else random.randint(0, 2**31 - 1)
        self.offload_model = offload_model
        self.use_prompt_extend = use_prompt_extend

        # NOTE(review): is_vl=True is passed for every task, including ones
        # that may be called without an image -- confirm the expander accepts
        # image-less calls in that mode.
        self.prompt_expander = (
            QwenPromptExpander(
                model_name=prompt_extend_model,
                task=task,
                is_vl=True,
                device=device_id,
            )
            if use_prompt_extend
            else None
        )
        self.pipeline = self._build_pipeline()

    # =========================
    # task classification
    # =========================
    @classmethod
    def _task_kind(cls, task):
        """Return the first entry of ``_TASK_KINDS`` contained in ``task``.

        Raises:
            ValueError: If ``task`` matches none of the known kinds.
        """
        for kind in cls._TASK_KINDS:
            if kind in task:
                return kind
        raise ValueError(task)

    # =========================
    # pipeline init
    # =========================
    def _build_pipeline(self):
        """Instantiate the ``wan`` pipeline class matching ``self.task``."""
        kind = self._task_kind(self.task)
        # Resolve by name so only the class actually needed is looked up.
        pipeline_cls = getattr(wan, self._PIPELINE_CLASS_NAMES[kind])
        return pipeline_cls(
            config=self.cfg,
            checkpoint_dir=self.ckpt_dir,
            device_id=self.device_id,
            rank=0,
        )

    # =========================
    # prompt expand
    # =========================
    def _expand(self, prompt, image=None):
        """Return the expanded prompt, or ``prompt`` unchanged.

        The prompt is returned as-is when prompt extension is disabled or
        when the expander reports a falsy ``status``.
        """
        if not self.use_prompt_extend:
            return prompt
        out = self.prompt_expander(prompt, image=image)
        return out.prompt if out.status else prompt

    # =========================
    # image loading
    # =========================
    @staticmethod
    def _load_rgb(image_path):
        """Load ``image_path`` as an RGB image, or return None if it is falsy.

        The file is opened in a ``with`` block so its handle is closed as soon
        as the converted copy has been produced.
        """
        if not image_path:
            return None
        with Image.open(image_path) as src:
            return src.convert("RGB")

    # =========================
    # OpenAI response packer
    # =========================
    def _pack(self, prompt, video_path, seed):
        """Build the OpenAI-style response dict for one generated video."""
        return {
            "id": f"wan_{uuid.uuid4().hex}",
            "object": "video.generation",
            "created": int(time.time()),
            "prompt": prompt,
            "video_path": video_path,
            "seed": seed,
        }

    # =========================
    # main entry (unified + serialized by the global lock)
    # =========================
    def generate(self, **kwargs):
        """OpenAI-style unified entry point.

        Recognised keyword arguments (all optional except ``prompt``):
            prompt: Text prompt (required).
            image_path: Input/reference image path; required for i2v tasks.
            size: Key into ``SIZE_CONFIGS`` / ``MAX_AREA_CONFIGS``
                (default ``"1280*720"``).
            seed: Per-request seed. ``None``/absent uses the instance seed;
                a negative value draws a fresh random seed.
            save_file: Output path (default: a random ``/tmp/<hex>.mp4``).
            frame_num, shift, solver, steps, guide_scale: Sampling options;
                ``None`` or ``0`` falls back to the built-in default.
            audio_path, enable_tts, tts_prompt_audio, tts_prompt_text,
            tts_text, num_clip, pose_video, infer_frames, start_from_ref:
                Forwarded to the pipeline for s2v tasks only.

        Returns:
            dict with keys ``id``, ``object``, ``created``, ``prompt`` (the
            prompt actually used, i.e. after expansion), ``video_path`` and
            ``seed``.

        Raises:
            ValueError: If ``prompt`` is missing, if an i2v task gets no
                ``image_path``, or if the task kind is unknown.
            KeyError: If ``size`` is not a known size key.
        """
        # ---- cheap validation first, before any I/O or waiting on the lock
        prompt = kwargs.get("prompt")
        if prompt is None:
            raise ValueError("prompt is required")

        kind = self._task_kind(self.task)
        image_path = kwargs.get("image_path")
        if kind == "i2v" and not image_path:
            raise ValueError("image_path is required for i2v tasks")

        # `or` (not a .get default) so an explicit None also gets the default,
        # consistent with the other optional request fields below.
        size = kwargs.get("size") or "1280*720"
        size_cfg = SIZE_CONFIGS[size]

        seed = kwargs.get("seed")
        if seed is None:
            seed = self.seed
        elif seed < 0:
            seed = random.randint(0, 2**31 - 1)

        # Decode the image once, and only if something will consume it: the
        # ti2v/i2v pipelines take the image itself, the s2v pipeline takes the
        # path, and the prompt expander takes the image when enabled.
        needs_image = kind in ("ti2v", "i2v") or self.use_prompt_extend
        img = self._load_rgb(image_path) if needs_image else None

        # Sampling options shared by every pipeline kind.
        sampling = {
            "shift": kwargs.get("shift") or 5.0,
            "sample_solver": kwargs.get("solver") or "unipc",
            "sampling_steps": kwargs.get("steps") or 30,
            "guide_scale": kwargs.get("guide_scale") or 5.0,
            "seed": seed,
            "offload_model": self.offload_model,
        }

        with _GLOBAL_INFER_LOCK:
            prompt = self._expand(prompt, image=img)

            if kind == "s2v":
                video = self.pipeline.generate(
                    input_prompt=prompt,
                    ref_image_path=image_path,
                    audio_path=kwargs.get("audio_path"),
                    enable_tts=kwargs.get("enable_tts", False),
                    tts_prompt_audio=kwargs.get("tts_prompt_audio"),
                    tts_prompt_text=kwargs.get("tts_prompt_text"),
                    tts_text=kwargs.get("tts_text"),
                    num_repeat=kwargs.get("num_clip"),
                    pose_video=kwargs.get("pose_video"),
                    max_area=MAX_AREA_CONFIGS[size],
                    infer_frames=kwargs.get("infer_frames") or 80,
                    init_first_frame=kwargs.get("start_from_ref", False),
                    **sampling,
                )
            elif kind == "ti2v":
                video = self.pipeline.generate(
                    prompt,
                    img=img,
                    size=size_cfg,
                    max_area=MAX_AREA_CONFIGS[size],
                    frame_num=kwargs.get("frame_num") or 81,
                    **sampling,
                )
            elif kind == "t2v":
                video = self.pipeline.generate(
                    prompt,
                    size=size_cfg,
                    frame_num=kwargs.get("frame_num") or 81,
                    **sampling,
                )
            else:  # kind == "i2v"
                # NOTE(review): `size` is forwarded here as in the original
                # code -- confirm the installed WanI2V.generate accepts it.
                video = self.pipeline.generate(
                    prompt,
                    img,
                    size=size_cfg,
                    max_area=MAX_AREA_CONFIGS[size],
                    frame_num=kwargs.get("frame_num") or 81,
                    **sampling,
                )

            # Save the video. `or` so an explicit save_file=None still yields
            # a real path in the response.
            video_path = kwargs.get("save_file") or f"/tmp/{uuid.uuid4().hex}.mp4"

            from wan.utils.utils import save_video

            save_video(
                tensor=video[None],
                save_file=video_path,
                fps=self.cfg.sample_fps,
                nrow=1,
                normalize=True,
                value_range=(-1, 1),
            )
            return self._pack(prompt, video_path, seed)