- Wan2.2-TI2V-5B GPU 视频推理
- ahserver + longtasks 异步任务队列
- OpenAI 兼容 API: POST /api/submit, GET /api/task, GET /api/status
- 模型常驻内存,惰性加载
- 全局串行推理锁(GPU 安全)
- 支持 t2v/i2v/ti2v/s2v 四种任务类型
239 lines
7.1 KiB
Python
239 lines
7.1 KiB
Python
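"""
Wan22 - OpenAI-compatible Video Generation Runtime

特征 (Features):
1. OpenAI风格返回 (OpenAI-style response objects)
2. 严格串行推理锁(GPU安全) (strict serial inference lock, for GPU safety)
3. 支持 t2v / i2v / ti2v / s2v (supports t2v / i2v / ti2v / s2v tasks)
4. 模型常驻内存,跨任务复用 (the model stays resident in memory and is reused across tasks)
"""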
|
||
import os
|
||
import time
|
||
import uuid
|
||
import threading
|
||
import random
|
||
from dataclasses import dataclass
|
||
|
||
import torch
|
||
from PIL import Image
|
||
|
||
import wan
|
||
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS
|
||
from wan.utils.prompt_extend import QwenPromptExpander
|
||
|
||
|
||
# Global inference lock (critical). It is module-level, so every Wan22 instance
# in this process shares it: Wan22.generate() holds it for the whole request
# (prompt expansion, sampling and saving), so only one generation runs at a time.
_GLOBAL_INFER_LOCK = threading.Lock()
|
||
|
||
|
||
@dataclass
class OpenAIVideoResponse:
    """OpenAI-style record describing one generated video.

    Its fields match the keys of the dict built by ``Wan22._pack``.
    """

    # Response identifier; Wan22._pack builds it as "wan_" + a uuid4 hex string.
    id: str

    # Object-type tag; Wan22._pack uses "video.generation".
    object: str

    # Creation time as whole Unix seconds (int(time.time()) in Wan22._pack).
    created: int

    # Prompt actually used for generation (after optional prompt expansion).
    prompt: str

    # Filesystem path of the saved video file.
    video_path: str

    # Seed that was passed to the pipeline for this generation.
    seed: int
|
||
|
||
|
||
class Wan22:
    """
    Wan22 - OpenAI-compatible video generation runtime.

    Wraps one Wan pipeline (t2v / i2v / ti2v / s2v, chosen from ``task``).
    The pipeline is built once in ``__init__`` and reused by every
    ``generate`` call. All ``generate`` calls in the process are serialized
    by the module-level ``_GLOBAL_INFER_LOCK``.
    """

    # Task kinds, checked in this order as substrings of the task name.
    # "ti2v" must come before "i2v": "ti2v-5B" also contains "i2v", and checking
    # "i2v" first would silently build the wrong pipeline for the default task.
    _TASK_KINDS = ("ti2v", "t2v", "i2v", "s2v")

    # Name of the ``wan`` pipeline class for each task kind.
    _PIPELINE_CLASSES = {
        "t2v": "WanT2V",
        "i2v": "WanI2V",
        "ti2v": "WanTI2V",
        "s2v": "WanS2V",
    }

    # Default request ``size`` per task kind. Wan2.2's SUPPORTED_SIZES lists only
    # 1280*704 / 704*1280 for ti2v-5B (verify against wan.configs if it changes);
    # every other kind keeps the previous default of 1280*720.
    _DEFAULT_SIZES = {"ti2v": "1280*704"}
    _FALLBACK_SIZE = "1280*720"

    def __init__(
        self,
        ckpt_dir: str,
        task: str = "ti2v-5B",
        device_id: int = 0,
        use_prompt_extend: bool = False,
        prompt_extend_model: "str | None" = None,
        seed: int = -1,
        offload_model: bool = True,
    ):
        """
        Load the pipeline for ``task``, plus the Qwen prompt expander if enabled.

        Args:
            ckpt_dir: checkpoint directory handed to the wan pipeline.
            task: key of ``WAN_CONFIGS`` (e.g. "ti2v-5B"). Its name must
                contain one of ti2v / t2v / i2v / s2v.
            device_id: device index for the pipeline and the prompt expander.
            use_prompt_extend: expand prompts with ``QwenPromptExpander``.
            prompt_extend_model: model name for the prompt expander.
            seed: default seed for requests that do not pass their own. A
                negative value picks one random seed here, which is then
                reused by all such requests.
            offload_model: forwarded to the pipeline's ``generate``.

        Raises:
            ValueError: ``ckpt_dir`` is empty, ``task`` is not a ``WAN_CONFIGS``
                key, or the task kind is unsupported.
        """
        # Explicit raises instead of ``assert`` so validation survives ``python -O``.
        if not ckpt_dir:
            raise ValueError("ckpt_dir required")
        if task not in WAN_CONFIGS:
            raise ValueError(f"unknown task: {task!r}")
        # Fail fast on unsupported kinds, before any model is loaded.
        self._task_kind(task)

        self.ckpt_dir = ckpt_dir
        self.task = task
        self.device_id = device_id
        self.cfg = WAN_CONFIGS[task]

        self.seed = seed if seed >= 0 else random.randint(0, 2**31 - 1)
        self.offload_model = offload_model

        self.use_prompt_extend = use_prompt_extend
        self.prompt_expander = (
            QwenPromptExpander(
                model_name=prompt_extend_model,
                task=task,
                is_vl=True,
                device=device_id,
            )
            if use_prompt_extend
            else None
        )

        self.pipeline = self._build_pipeline()

    @classmethod
    def _task_kind(cls, task):
        """Return the first of ti2v / t2v / i2v / s2v contained in ``task``.

        Raises:
            ValueError: the task name contains none of them.
        """
        for kind in cls._TASK_KINDS:
            if kind in task:
                return kind
        raise ValueError(task)

    def _build_pipeline(self):
        """Instantiate the wan pipeline class for this task, as rank 0 on ``device_id``."""
        # Look up the class by name so only the class actually needed is accessed.
        pipeline_cls = getattr(wan, self._PIPELINE_CLASSES[self._task_kind(self.task)])
        return pipeline_cls(
            config=self.cfg,
            checkpoint_dir=self.ckpt_dir,
            device_id=self.device_id,
            rank=0,
        )

    def _expand(self, prompt, image=None):
        """Return the expanded prompt.

        The original ``prompt`` is returned when expansion is disabled or the
        expander reports failure.
        """
        if not self.use_prompt_extend:
            return prompt

        # NOTE(review): the expander is built with is_vl=True; confirm it
        # accepts image=None (text-only requests) for the installed wan version.
        out = self.prompt_expander(prompt, image=image)
        return out.prompt if out.status else prompt

    @staticmethod
    def _load_image(image_path):
        """Open ``image_path`` as an RGB PIL image, or return None when no path is given."""
        if not image_path:
            return None
        # The context manager closes the file handle. convert() returns an
        # in-memory copy that stays valid after the file is closed.
        with Image.open(image_path) as img:
            return img.convert("RGB")

    def _resolve_seed(self, seed=None):
        """Pick the seed for one request.

        ``None`` means the instance default ``self.seed``. A negative value
        means a fresh random seed. Any other value is used as given.
        """
        if seed is None:
            return self.seed
        seed = int(seed)
        return seed if seed >= 0 else random.randint(0, 2**31 - 1)

    def _option(self, kwargs, key, cfg_attr):
        """Return ``kwargs[key]``, or ``self.cfg.<cfg_attr>`` when the key is absent or None."""
        value = kwargs.get(key)
        if value is None:
            # Forwarding None would override the pipeline's own defaults
            # (e.g. sampling_steps=None), so fall back to the task config value,
            # as Wan2.2's generate.py does.
            value = getattr(self.cfg, cfg_attr, None)
        return value

    def _pack(self, prompt, video_path, seed):
        """Build the OpenAI-style response dict for one generated video."""
        from dataclasses import asdict

        # OpenAIVideoResponse is the single definition of the response schema.
        return asdict(
            OpenAIVideoResponse(
                id=f"wan_{uuid.uuid4().hex}",
                object="video.generation",
                created=int(time.time()),
                prompt=prompt,
                video_path=video_path,
                seed=seed,
            )
        )

    def generate(self, **kwargs):
        """
        OpenAI-style unified entry point: run one generation and save it to a video file.

        Calls are serialized process-wide by ``_GLOBAL_INFER_LOCK`` (GPU safety).

        Recognized kwargs:
            prompt (required), image_path (required for i2v / s2v, optional
            for ti2v), size, seed, frame_num, shift, solver, steps,
            guide_scale, save_file.
            s2v only: audio_path, enable_tts, tts_prompt_audio,
            tts_prompt_text, tts_text, num_clip, pose_video, infer_frames,
            start_from_ref.
            Unset frame_num / shift / steps / guide_scale fall back to the
            task config (``frame_num``, ``sample_shift``, ``sample_steps``,
            ``sample_guide_scale``).

        Returns:
            dict with keys id / object / created / prompt / video_path / seed.

        Raises:
            ValueError: missing prompt, missing image for i2v / s2v, or an
                unknown size.
        """
        with _GLOBAL_INFER_LOCK:
            kind = self._task_kind(self.task)

            prompt = kwargs.get("prompt")
            if prompt is None:
                raise ValueError("prompt is required")
            image_path = kwargs.get("image_path")
            if kind in ("i2v", "s2v") and not image_path:
                raise ValueError(f"image_path is required for task {self.task!r}")

            size = kwargs.get("size") or self._DEFAULT_SIZES.get(kind, self._FALLBACK_SIZE)
            if size not in SIZE_CONFIGS:
                raise ValueError(f"unsupported size: {size!r}")
            size_cfg = SIZE_CONFIGS[size]
            max_area = MAX_AREA_CONFIGS[size]

            # Load the image once; it feeds both prompt expansion and the pipeline.
            image = self._load_image(image_path)
            prompt = self._expand(prompt, image=image)

            seed = self._resolve_seed(kwargs.get("seed"))
            frame_num = self._option(kwargs, "frame_num", "frame_num")
            # Sampling arguments shared by every pipeline's generate().
            sampling = {
                "shift": self._option(kwargs, "shift", "sample_shift"),
                "sample_solver": kwargs.get("solver", "unipc"),
                "sampling_steps": self._option(kwargs, "steps", "sample_steps"),
                "guide_scale": self._option(kwargs, "guide_scale", "sample_guide_scale"),
                "seed": seed,
                "offload_model": self.offload_model,
            }

            if kind == "t2v":
                video = self.pipeline.generate(
                    prompt,
                    size=size_cfg,
                    frame_num=frame_num,
                    **sampling,
                )
            elif kind == "i2v":
                # WanI2V.generate sizes the output from max_area; in Wan2.2 it
                # has no ``size`` parameter (verify if upgrading wan).
                video = self.pipeline.generate(
                    prompt,
                    image,
                    max_area=max_area,
                    frame_num=frame_num,
                    **sampling,
                )
            elif kind == "ti2v":
                # image may be None: Wan2.2's generate.py also calls WanTI2V
                # with img=None for text-only generation.
                video = self.pipeline.generate(
                    prompt,
                    img=image,
                    size=size_cfg,
                    max_area=max_area,
                    frame_num=frame_num,
                    **sampling,
                )
            elif kind == "s2v":
                video = self.pipeline.generate(
                    input_prompt=prompt,
                    ref_image_path=image_path,
                    audio_path=kwargs.get("audio_path"),
                    enable_tts=kwargs.get("enable_tts", False),
                    tts_prompt_audio=kwargs.get("tts_prompt_audio"),
                    tts_prompt_text=kwargs.get("tts_prompt_text"),
                    tts_text=kwargs.get("tts_text"),
                    num_repeat=kwargs.get("num_clip"),
                    pose_video=kwargs.get("pose_video"),
                    max_area=max_area,
                    infer_frames=kwargs.get("infer_frames", 80),
                    init_first_frame=kwargs.get("start_from_ref", False),
                    **sampling,
                )
            else:
                raise ValueError(self.task)

            # An explicit save_file=None also falls back to a random /tmp path.
            video_path = kwargs.get("save_file") or f"/tmp/{uuid.uuid4().hex}.mp4"

            from wan.utils.utils import save_video

            # NOTE(review): for s2v, upstream Wan2.2 also merges the audio track
            # into the saved file; this runtime writes video frames only.
            save_video(
                tensor=video[None],
                save_file=video_path,
                fps=self.cfg.sample_fps,
                nrow=1,
                normalize=True,
                value_range=(-1, 1),
            )

            return self._pack(prompt, video_path, seed)
|