realesrgan-service/workers/wan22_wrapper.py
yumoqing 769fc4968e feat: initial wan22 video generation service
- Wan2.2-TI2V-5B GPU 视频推理
- ahserver + longtasks 异步任务队列
- OpenAI 兼容 API: POST /api/submit, GET /api/task, GET /api/status
- 模型常驻内存,惰性加载
- 全局串行推理锁(GPU 安全)
- 支持 t2v/i2v/ti2v/s2v 四种任务类型
2026-06-09 22:00:22 +08:00

239 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Wan22 - OpenAI-compatible Video Generation Runtime
特征:
1. OpenAI风格返回
2. 严格串行推理锁GPU安全
3. 支持 t2v / i2v / ti2v / s2v
4. 模型常驻内存,跨任务复用
"""
import os
import time
import uuid
import threading
import random
from dataclasses import dataclass
import torch
from PIL import Image
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS
from wan.utils.prompt_extend import QwenPromptExpander
# Global execution lock (critical). It lives at module level, so it is shared
# by every Wan22 instance in this process; Wan22.generate holds it for the
# whole request to keep GPU inference strictly serial. A threading.Lock only
# coordinates threads within one process, not separate worker processes.
_GLOBAL_INFER_LOCK = threading.Lock()
@dataclass()
class OpenAIVideoResponse:
    """
    Typed record with the same fields as the response dict built by
    ``Wan22._pack``.
    """

    # Response identifier (``Wan22._pack`` uses "wan_<uuid hex>").
    id: str
    # Object type tag ("video.generation" in ``Wan22._pack``).
    object: str
    # Creation time as whole seconds since the Unix epoch.
    created: int
    # Prompt actually used for generation (after optional expansion).
    prompt: str
    # Filesystem path of the saved video file.
    video_path: str
    # Seed used for sampling.
    seed: int
class Wan22:
    """
    Wan22 - OpenAI-compatible video generation runtime.

    Wraps a single Wan2.2 pipeline (t2v / i2v / ti2v / s2v, chosen from the
    task name). The pipeline is built once in ``__init__`` and reused by
    every ``generate`` call. ``generate`` holds the module-level
    ``_GLOBAL_INFER_LOCK`` while it runs inference, so generations in this
    process never overlap.
    """

    # Task families in match order. "ti2v" must be tested before "i2v".
    # "i2v" is a substring of "ti2v", so the old t2v -> i2v -> ti2v order
    # sent the default "ti2v-5B" task to WanI2V and made ti2v unreachable.
    _TASK_KINDS = ("ti2v", "t2v", "i2v", "s2v")

    # ``wan`` pipeline class name per task family. Names are resolved lazily,
    # so only the class that is actually needed must exist in the package.
    _PIPELINE_CLASSES = {
        "t2v": "WanT2V",
        "i2v": "WanI2V",
        "ti2v": "WanTI2V",
        "s2v": "WanS2V",
    }

    # Default ``size`` per task family. Upstream Wan2.2 lists only
    # 1280*704 / 704*1280 as supported for ti2v-5B, so the generic 1280*720
    # default is replaced there (based on upstream SUPPORTED_SIZES; re-check
    # if the configs change).
    _DEFAULT_SIZES = {"ti2v": "1280*704"}
    _FALLBACK_SIZE = "1280*720"

    def __init__(
        self,
        ckpt_dir: str,
        task: str = "ti2v-5B",
        device_id: int = 0,
        use_prompt_extend: bool = False,
        prompt_extend_model: str = None,
        seed: int = -1,
        offload_model: bool = True,
    ):
        """
        Build the pipeline for ``task`` and keep it resident in memory.

        Args:
            ckpt_dir: checkpoint directory passed to the wan pipeline.
            task: key into ``WAN_CONFIGS`` (e.g. "ti2v-5B").
            device_id: device index for the pipeline and the prompt expander.
            use_prompt_extend: build a QwenPromptExpander and rewrite prompts
                before inference.
            prompt_extend_model: model name for the QwenPromptExpander.
            seed: default sampling seed. A negative value draws one random
                seed here; it is reused by every request that does not supply
                its own ``seed``.
            offload_model: forwarded to the pipeline's ``generate``.

        Raises:
            ValueError: empty ``ckpt_dir``, a task missing from
                ``WAN_CONFIGS``, or a task family this wrapper cannot run.
        """
        # Explicit raises: ``assert`` checks are stripped under ``python -O``.
        if not ckpt_dir:
            raise ValueError("ckpt_dir required")
        if task not in WAN_CONFIGS:
            raise ValueError(
                f"unknown task {task!r}; expected one of {sorted(WAN_CONFIGS)}"
            )
        self.ckpt_dir = ckpt_dir
        self.task = task
        # Resolve the family first, so an unsupported task fails before any
        # model is loaded.
        self._kind = self._task_kind(task)
        self.device_id = device_id
        self.cfg = WAN_CONFIGS[task]
        self.seed = seed if seed >= 0 else random.randint(0, 2**31 - 1)
        self.offload_model = offload_model
        self.use_prompt_extend = use_prompt_extend
        self.prompt_expander = (
            QwenPromptExpander(
                model_name=prompt_extend_model,
                task=task,
                is_vl=True,
                device=device_id,
            )
            if use_prompt_extend
            else None
        )
        self.pipeline = self._build_pipeline()

    @classmethod
    def _task_kind(cls, task):
        """Return the task family ("ti2v", "t2v", "i2v" or "s2v") of ``task``.

        Raises:
            ValueError: if ``task`` matches none of the supported families.
        """
        for kind in cls._TASK_KINDS:
            if kind in task:
                return kind
        raise ValueError(task)

    @staticmethod
    def _pick(kwargs, key, default=None):
        """Return ``kwargs[key]``, or ``default`` when it is missing or None."""
        value = kwargs.get(key)
        return default if value is None else value

    @staticmethod
    def _load_image(image_path):
        """Open ``image_path`` as an RGB PIL image; return None if no path is given."""
        if not image_path:
            return None
        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(image_path) as im:
            return im.convert("RGB")

    def _build_pipeline(self):
        """Instantiate the ``wan`` pipeline class for this task family."""
        pipeline_cls = getattr(wan, self._PIPELINE_CLASSES[self._kind])
        return pipeline_cls(
            config=self.cfg,
            checkpoint_dir=self.ckpt_dir,
            device_id=self.device_id,
            rank=0,
        )

    def _expand(self, prompt, image=None):
        """Return the expanded prompt. Return ``prompt`` unchanged when extension
        is disabled or the expander reports failure."""
        if not self.use_prompt_extend:
            return prompt
        out = self.prompt_expander(prompt, image=image)
        return out.prompt if out.status else prompt

    def _pack(self, prompt, video_path, seed):
        """Build the OpenAI-style response dict for one finished generation."""
        return {
            "id": f"wan_{uuid.uuid4().hex}",
            "object": "video.generation",
            "created": int(time.time()),
            "prompt": prompt,
            "video_path": video_path,
            "seed": seed,
        }

    def _resolve_seed(self, kwargs):
        """Return the request ``seed`` if given and non-negative, else the instance seed."""
        requested = kwargs.get("seed")
        if requested is None or int(requested) < 0:
            return self.seed
        return int(requested)

    def _sample_kwargs(self, kwargs, seed):
        """Collect the sampling options shared by every pipeline's ``generate``.

        Options that are missing or None fall back to the task config's
        ``sample_shift`` / ``sample_steps`` / ``sample_guide_scale``, the same
        defaults upstream ``generate.py`` uses. If the config lacks an
        attribute, None is passed through, which matches the old behaviour.
        """
        return {
            "shift": self._pick(
                kwargs, "shift", getattr(self.cfg, "sample_shift", None)
            ),
            "sample_solver": self._pick(kwargs, "solver", "unipc"),
            "sampling_steps": self._pick(
                kwargs, "steps", getattr(self.cfg, "sample_steps", None)
            ),
            "guide_scale": self._pick(
                kwargs, "guide_scale", getattr(self.cfg, "sample_guide_scale", None)
            ),
            "seed": seed,
            "offload_model": self.offload_model,
        }

    def _run_pipeline(self, prompt, img, image_path, size, seed, kwargs):
        """Dispatch one inference call to the pipeline for this task family."""
        common = self._sample_kwargs(kwargs, seed)
        frame_num = self._pick(
            kwargs, "frame_num", getattr(self.cfg, "frame_num", None)
        )
        kind = self._kind
        if kind == "t2v":
            return self.pipeline.generate(
                prompt,
                size=SIZE_CONFIGS[size],
                frame_num=frame_num,
                **common,
            )
        if kind == "i2v":
            # Upstream WanI2V.generate takes no ``size`` argument; resolution
            # comes from max_area and the input image (per Wan2.2 generate.py).
            return self.pipeline.generate(
                prompt,
                img,
                max_area=MAX_AREA_CONFIGS[size],
                frame_num=frame_num,
                **common,
            )
        if kind == "ti2v":
            # img may be None; upstream TI2V then runs as text-to-video.
            return self.pipeline.generate(
                prompt,
                img=img,
                size=SIZE_CONFIGS[size],
                max_area=MAX_AREA_CONFIGS[size],
                frame_num=frame_num,
                **common,
            )
        if kind == "s2v":
            return self.pipeline.generate(
                input_prompt=prompt,
                ref_image_path=image_path,
                audio_path=kwargs.get("audio_path"),
                enable_tts=kwargs.get("enable_tts", False),
                tts_prompt_audio=kwargs.get("tts_prompt_audio"),
                tts_prompt_text=kwargs.get("tts_prompt_text"),
                tts_text=kwargs.get("tts_text"),
                num_repeat=kwargs.get("num_clip"),
                pose_video=kwargs.get("pose_video"),
                max_area=MAX_AREA_CONFIGS[size],
                infer_frames=kwargs.get("infer_frames", 80),
                init_first_frame=kwargs.get("start_from_ref", False),
                **common,
            )
        raise ValueError(self.task)

    def _save(self, video, save_file=None):
        """Write ``video`` to ``save_file`` (or a new temp .mp4) and return the path."""
        import tempfile

        from wan.utils.utils import save_video

        video_path = save_file or os.path.join(
            tempfile.gettempdir(), f"{uuid.uuid4().hex}.mp4"
        )
        save_video(
            tensor=video[None],  # add a leading batch axis for save_video
            save_file=video_path,
            fps=self.cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )
        return video_path

    def generate(self, **kwargs):
        """
        OpenAI-style unified entry point: run one generation and save the video.

        The request is validated first, without the lock. Prompt extension,
        inference and saving then run while holding ``_GLOBAL_INFER_LOCK``.

        Recognized kwargs:
            prompt (required): text prompt.
            image_path: reference image. Required for i2v and s2v; optional
                for ti2v (text-only generation when omitted).
            size: key into SIZE_CONFIGS / MAX_AREA_CONFIGS. Defaults to
                "1280*704" for ti2v and "1280*720" otherwise.
            seed: per-request seed. The instance seed is used when it is
                missing or negative.
            frame_num, shift, solver, steps, guide_scale: sampling options.
                Missing/None values fall back to the task config's defaults.
            audio_path, enable_tts, tts_prompt_audio, tts_prompt_text,
            tts_text, num_clip, pose_video, infer_frames, start_from_ref:
                s2v-only options.
            save_file: output path. Defaults to a unique .mp4 in the system
                temp directory.

        Returns:
            dict with keys id, object, created, prompt, video_path, seed.

        Raises:
            ValueError: missing prompt, missing required image, or a size
                absent from SIZE_CONFIGS / MAX_AREA_CONFIGS.
        """
        prompt = kwargs.get("prompt")
        if prompt is None:
            raise ValueError("prompt is required")
        image_path = kwargs.get("image_path")
        if not image_path and self._kind in ("i2v", "s2v"):
            raise ValueError(f"image_path is required for task {self.task!r}")
        size = self._pick(
            kwargs, "size", self._DEFAULT_SIZES.get(self._kind, self._FALLBACK_SIZE)
        )
        if size not in SIZE_CONFIGS or size not in MAX_AREA_CONFIGS:
            raise ValueError(f"unsupported size {size!r}")
        seed = self._resolve_seed(kwargs)

        with _GLOBAL_INFER_LOCK:
            # Decode the reference image once; it feeds both prompt extension
            # and inference.
            img = self._load_image(image_path)
            prompt = self._expand(prompt, image=img)
            video = self._run_pipeline(prompt, img, image_path, size, seed, kwargs)
            video_path = self._save(video, kwargs.get("save_file"))
            return self._pack(prompt, video_path, seed)