commit 769fc4968eb503b4290f5139b88f44d9c6ecc021 Author: yumoqing Date: Tue Jun 9 22:00:22 2026 +0800 feat: initial wan22 video generation service - Wan2.2-TI2V-5B GPU 视频推理 - ahserver + longtasks 异步任务队列 - OpenAI 兼容 API: POST /api/submit, GET /api/task, GET /api/status - 模型常驻内存,惰性加载 - 全局串行推理锁(GPU 安全) - 支持 t2v/i2v/ti2v/s2v 四种任务类型 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..db64852 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +*.pyc +__pycache__/ +*.bak +wan22-service.log +py3/ +venv/ +repo/ +logs/ +files/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..b97106f --- /dev/null +++ b/README.md @@ -0,0 +1,177 @@ +# Wan22 Video Generation Service + +Wan2.2-TI2V-5B 视频生成服务,基于 ahserver + longtasks 提供 OpenAI 兼容的异步视频生成 API。 + +## Architecture + +``` +HTTP Request → ahserver (port 8079) → submit.dspy → longtasks.submit_task() + ↓ (Redis Queue) + Wan22Tasks.process_task() + ↓ + Wan22.generate() [GPU] + ↓ + save to /data/ymq/wan22-outputs/ + ↓ + task.dspy ← longtasks.get_status() +``` + +- **串行推理**: GPU 全局锁 `_GLOBAL_INFER_LOCK`,一次只跑一个任务 +- **模型常驻**: 首次任务加载 Wan2.2 模型,后续任务复用,无需重复加载 +- **异步队列**: longtasks 通过 Redis 管理任务队列,支持失败重试 + +## API 接口 + +### 1. 提交视频生成任务 + +``` +POST /api/submit +Content-Type: application/json + +{ + "prompt": "A cinematic scene of...", // 必填,视频描述 + "size": "1280*720", // 可选,默认 1280*720 + "frame_num": 81, // 可选,帧数 (4n+1, 17~129) + "sample_steps": 50, // 可选,采样步数 + "sample_guide_scale": 5.0, // 可选,引导比例 + "base_seed": 42, // 可选,随机种子 + "task_id": "my_custom_id" // 可选,自定义任务ID +} +``` + +**响应**: +```json +{ + "task_id": "a1b2c3d4e5f6", // 用于查询状态 + "status": "queued", + "prompt": "A cinematic scene...", + "size": "1280*720", + "frame_num": 81, + "message": "task submitted", + "check_url": "/api/task?task_id=a1b2c3d4e5f6" +} +``` + +### 2. 查询任务状态 + +``` +GET /api/task?task_id=a1b2c3d4e5f6 +``` + +**响应** (PENDING): +```json +{ + "status": "PENDING", + "created_at": 1712345678.0, + "started_at": null, + "finished_at": null +} +``` + +**响应** (SUCCEEDED): +```json +{ + "status": "SUCCEEDED", + "task_id": "a1b2c3d4e5f6", + "video_url": "/idfile?path=a1b2c3d4e5f6.mp4", + "video_path": "/data/ymq/wan22-outputs/a1b2c3d4e5f6.mp4", + "size": "1280*720", + "frame_num": 81, + "file_size": 12345678, + "prompt": "A cinematic scene...", + "seed": 42, + "created_at": 1712345678.0, + "started_at": 1712345680.0, + "finished_at": 1712345900.0 +} +``` + +**响应** (FAILED): +```json +{ + "status": "FAILED", + "task_id": "a1b2c3d4e5f6", + "error": "CUDA out of memory", + "created_at": 1712345678.0 +} +``` + +### 3. 服务状态 + +``` +GET /api/status +``` + +```json +{ + "service": "wan22-video-generation", + "model": "Wan2.2-TI2V-5B", + "gpu_id": 2, + "gpus": [ + {"id": 0, "util": 23, "mem_used": 5120, "mem_total": 24564}, + {"id": 1, "util": 0, "mem_used": 4, "mem_total": 24564}, + {"id": 2, "util": 45, "mem_used": 8192, "mem_total": 24564} + ] +} +``` + +## 视频下载 + +生成完成后,通过 `video_url` 下载视频: + +``` +GET /idfile?path=a1b2c3d4e5f6.mp4 +``` + +或在浏览器中拼接 URL: +``` +http://:8079/idfile?path=a1b2c3d4e5f6.mp4 +``` + +## 部署 + +```bash +# 启动 +cd ~/wan22-service +WAN22_GPU_ID=2 ./start.sh + +# 停止 +./stop.sh + +# 查看日志 +tail -f wan22-service.log +``` + +环境变量: +- `WAN22_GPU_ID`: GPU 设备号 (默认 2) + +## 文件结构 + +``` +wan22-service/ +├── ah.py # 主入口: ahserver + longtasks 初始化 +├── app/ +│ └── api/ +│ ├── submit/index.dspy # POST /api/submit - 提交任务 +│ ├── task/index.dspy # GET /api/task - 查询状态 +│ └── status/index.dspy # GET /api/status - 服务状态 +├── conf/ +│ └── config.json # ahserver 配置 (端口 8079) +├── workers/ +│ ├── generate.py # 任务执行逻辑 (惰性加载 Wan22) +│ └── wan22_wrapper.py # Wan22 类 (OpenAI 风格封装) +├── repo/ # Wan2.2 推理代码 +├── py3/ # Python venv +├── start.sh / stop.sh +├── skill/ # Hermes skill 文档 +├── README.md +└── wan22-service.log +``` + +## Dependencies + +- ahserver (Web framework) +- longtasks (Async task queue via Redis) +- sqlor (Optional, for database operations) +- torch + torchvision (GPU inference) +- wan (Wan2.2 repo, local at `repo/wan/`) diff --git a/ah.py b/ah.py new file mode 100644 index 0000000..0c75e81 --- /dev/null +++ b/ah.py @@ -0,0 +1,49 @@ +# -*- coding:utf-8 -*- +import os +from ahserver.webapp import webapp +from ahserver.serverenv import ServerEnv +from ahserver.configuredServer import add_startup +from longtasks.longtasks import LongTasks, schedule_once +from appPublic.log import debug + + +class Wan22Tasks(LongTasks): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.gpu_id = int(os.environ.get('WAN22_GPU_ID', '2')) + + async def process_task(self, payload: dict, workid: int = None): + import json + if isinstance(payload, str): + payload = json.loads(payload) + task_type = payload.get('task_type', '') + debug(f'Wan22Tasks processing: type={task_type}') + if task_type == 'generate_video': + from workers.generate import run_generate + return await run_generate(self, payload) + raise ValueError(f'Unknown task_type: {task_type}') + + +async def on_app_built(app): + env = ServerEnv() + longtasks = env.longtasks + if longtasks: + schedule_once(0.1, longtasks.run) + debug(f'longtasks worker started, GPU: {longtasks.gpu_id}') + + +def init(): + env = ServerEnv() + longtasks = Wan22Tasks( + 'redis://127.0.0.1:6379', + 'wan22', + worker_cnt=1, + stuck_seconds=3600, + max_age_hours=24 + ) + env.longtasks = longtasks + add_startup(on_app_built) + + +if __name__ == '__main__': + webapp(init) diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/api/status/index.dspy b/app/api/status/index.dspy new file mode 100644 index 0000000..b14e1a7 --- /dev/null +++ b/app/api/status/index.dspy @@ -0,0 +1,31 @@ +# -*- coding:utf-8 -*- +# GET /api/status - 服务状态 + +import subprocess +import json + +result = { + 'service': 'wan22-video-generation', + 'model': 'Wan2.2-TI2V-5B', + 'gpu_id': 2, + 'gpus': [] +} + +try: + out = subprocess.check_output( + ['nvidia-smi', '--query-gpu=index,utilization.gpu,memory.used,memory.total', + '--format=csv,noheader,nounits'], + timeout=5 + ).decode().strip() + for line in out.split('\n'): + parts = [p.strip() for p in line.split(',')] + result['gpus'].append({ + 'id': int(parts[0]), + 'util': int(parts[1]), + 'mem_used': int(parts[2]), + 'mem_total': int(parts[3]) + }) +except Exception: + pass + +return json.dumps(result) diff --git a/app/api/submit/index.dspy b/app/api/submit/index.dspy new file mode 100644 index 0000000..756f3c8 --- /dev/null +++ b/app/api/submit/index.dspy @@ -0,0 +1,71 @@ +# -*- coding:utf-8 -*- +# POST /api/submit - 提交视频生成任务 + +import json +import uuid +from ahserver.serverenv import ServerEnv + +method = request.method + +if method == 'POST': + prompt = params_kw.get('prompt', '') + if not prompt: + return json.dumps({'error': 'prompt is required'}, ensure_ascii=False) + + task_id = params_kw.get('task_id', str(uuid.uuid4()).replace("-", "")[:12]) + image = params_kw.get('image', None) + size = params_kw.get('size', '1280*720') + frame_num = params_kw.get('frame_num', 81) + sample_steps = params_kw.get('sample_steps', None) + sample_guide_scale = params_kw.get('sample_guide_scale', None) + base_seed = params_kw.get('base_seed', None) + + valid_sizes = ['704*1280', '1280*704'] + if size not in valid_sizes: + return json.dumps({'error': f'invalid size, must be one of: {valid_sizes}'}, ensure_ascii=False) + + payload = { + 'task_type': 'generate_video', + 'task_id': task_id, + 'prompt': prompt, + 'image': image, + 'size': size, + 'frame_num': int(frame_num), + 'sample_steps': int(sample_steps) if sample_steps else None, + 'sample_guide_scale': float(sample_guide_scale) if sample_guide_scale else None, + 'base_seed': int(base_seed) if base_seed else None, + } + + env = ServerEnv() + longtasks = env.longtasks + if longtasks is None: + return json.dumps({'error': 'service not ready'}, ensure_ascii=False) + + result = await longtasks.submit_task(payload) + real_task_id = result.get('task_id', str(result)) if isinstance(result, dict) else str(result) + + return json.dumps({ + 'task_id': real_task_id, + 'status': 'queued', + 'prompt': prompt[:100], + 'size': size, + 'frame_num': payload['frame_num'], + 'message': 'task submitted', + 'check_url': f'/api/task?task_id={real_task_id}' + }, ensure_ascii=False) + +else: + return json.dumps({ + 'usage': 'POST with JSON body', + 'params': { + 'prompt': 'string (required)', + 'image': 'string (optional, server path for I2V)', + 'size': 'string (default 1280*720)', + 'frame_num': 'int (default 81, 4n+1, range 17-129)', + 'sample_steps': 'int (optional)', + 'sample_guide_scale': 'float (optional)', + 'base_seed': 'int (optional)', + 'task_id': 'string (optional, auto-generated)', + }, + 'valid_sizes': ['704*1280', '1280*704'] + }, ensure_ascii=False) diff --git a/app/api/task/index.dspy b/app/api/task/index.dspy new file mode 100644 index 0000000..a1137df --- /dev/null +++ b/app/api/task/index.dspy @@ -0,0 +1,17 @@ +# -*- coding:utf-8 -*- +# GET /api/task?task_id=xxx - 查询任务状态 + +import json +from ahserver.serverenv import ServerEnv + +task_id = params_kw.get('task_id', '') +if not task_id: + return json.dumps({'error': 'task_id is required'}, ensure_ascii=False) + +env = ServerEnv() +longtasks = env.longtasks +if longtasks is None: + return json.dumps({'error': 'service not ready'}, ensure_ascii=False) + +status = await longtasks.get_status(task_id) +return json.dumps(status) diff --git a/conf/config.json b/conf/config.json new file mode 100644 index 0000000..cb4edf3 --- /dev/null +++ b/conf/config.json @@ -0,0 +1,29 @@ +{ + "password_key": "Wan22Service2026Key", + "databases": {}, + "session_redis": { + "host": "127.0.0.1", + "port": 6379, + "db": 1 + }, + "website": { + "paths": [ + ["$[workdir]$/app", ""] + ], + "host": "0.0.0.0", + "port": 8079, + "coding": "utf-8", + "indexes": ["index.html", "index.dspy"], + "processors": [ + [".dspy", "dspy"] + ], + "startswiths": [ + { + "leading": "/idfile", + "registerfunction": "idfile" + } + ] + }, + "hot_reload": false, + "filesroot": "/data/ymq/wan22-outputs" +} diff --git a/skill/SKILL.md b/skill/SKILL.md new file mode 100644 index 0000000..0722941 --- /dev/null +++ b/skill/SKILL.md @@ -0,0 +1,87 @@ +--- +name: wan22-video-generation +description: Wan2.2-TI2V-5B 视频生成服务 — OpenAI 兼容 API,基于 ahserver + longtasks 异步任务队列,模型常驻 GPU 内存 +tags: [wan22, video-generation, ai-compute, gpu, ahserver, longtasks] +--- + +# Wan22 Video Generation Service + +Wan2.2-TI2V-5B 视频生成服务,部署在 GPU 服务器 (ymq@opencomputing.net) 上。 + +## 架构 + +``` +User/Hermes → Sage llmage/uapi → wan22-service (port 8079) → GPU 推理 +``` + +独立 ahserver 应用,通过 longtasks + Redis 管理异步视频生成任务。 + +## 关键文件 + +| 文件 | 路径 | 说明 | +|------|------|------| +| 主入口 | `~/wan22-service/ah.py` | ahserver + Wan22Tasks 初始化 | +| 提任务 | `~/wan22-service/app/api/submit/index.dspy` | POST /api/submit | +| 查状态 | `~/wan22-service/app/api/task/index.dspy` | GET /api/task?task_id=xxx | +| 推理执行 | `~/wan22-service/workers/generate.py` | 惰性加载 Wan22,进程内推理 | +| Wan22 类 | `~/wan22-service/workers/wan22_wrapper.py` | OpenAI 风格封装 | +| 配置文件 | `~/wan22-service/conf/config.json` | 端口 8079, Redis, filesroot | +| 启动脚本 | `~/wan22-service/start.sh` | WAN22_GPU_ID=2 | + +## API 接口 + +### 提交任务 +```bash +curl -X POST http://:8079/api/submit \ + -H "Content-Type: application/json" \ + -d '{"prompt":"A cinematic street at dawn, blue-grey tones","size":"1280*720","frame_num":81}' +``` + +### 查询状态 +```bash +curl "http://:8079/api/task?task_id=xxx" +``` + +### 下载视频 +```bash +curl -o output.mp4 "http://:8079/idfile?path=task_id.mp4" +``` + +## 设计要点 + +1. **串行推理锁**: `_GLOBAL_INFER_LOCK` (threading.Lock) 保证 GPU 安全 +2. **模型常驻**: Wan22 实例惰性初始化,首次任务加载后跨任务复用 +3. **异步队列**: longtasks (Redis) worker_cnt=1,一次处理一个任务 +4. **支持任务类型**: t2v / i2v / ti2v / s2v + +## 管理 + +```bash +ssh ymq@opencomputing.net +cd ~/wan22-service +./start.sh # 启动 (后台, nohup) +./stop.sh # 停止 (kill pid) +tail -f wan22-service.log # 查看日志 +``` + +## Sage 集成 + +通过 Sage 的 llmage + uapi 方式接入: + +```sql +-- 注册 uapi provider +INSERT INTO uapiprovider (...) VALUES ('wan22', 'Wan2.2', 'http://wan22.internal:8079'); + +-- 注册 API endpoint +INSERT INTO uapi (providerid, apiname, path, ...) VALUES ('wan22', 'video_generations', '/api/submit', ...); + +-- 注册 llm 模型 +INSERT INTO llm (model, ...) VALUES ('wan2.2-ti2v-5b', ...); +``` + +## 注意事项 + +- GPU OOM 时:减少 frame_num (最小 17) 或换小分辨率 +- task 未完成时返回 `PENDING` 状态,需轮询 +- 任务最长超时 3600 秒 (stuck_seconds) +- 已完成任务保留 24 小时 (max_age_hours) diff --git a/start.sh b/start.sh new file mode 100755 index 0000000..60c3ecd --- /dev/null +++ b/start.sh @@ -0,0 +1,5 @@ +#!/bin/bash +cd ~/wan22-service +export WAN22_GPU_ID=${WAN22_GPU_ID:-2} +nohup /share/vllm-0.8.5/bin/python ah.py > wan22-service.log 2>&1 & +echo "wan22-service started, PID: $!, GPU: $WAN22_GPU_ID" diff --git a/stop.sh b/stop.sh new file mode 100755 index 0000000..8ffabd1 --- /dev/null +++ b/stop.sh @@ -0,0 +1,6 @@ +#!/bin/bash +pkill -f "python ah.py.*wan22" 2>/dev/null +# fallback: kill by port +PID=$(ss -tlnp | grep 8079 | grep -oP 'pid=\K\d+') +[ -n "$PID" ] && kill $PID +echo "wan22-service stopped" diff --git a/workers/__init__.py b/workers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/workers/generate.py b/workers/generate.py new file mode 100755 index 0000000..d2441d4 --- /dev/null +++ b/workers/generate.py @@ -0,0 +1,134 @@ +# -*- coding:utf-8 -*- +""" +Wan2.2-TI2V-5B 视频生成 worker(进程内推理,模型常驻内存) + +使用 Wan22 类直接调用推理 pipeline, +替代原先每次任务启动子进程的方式。 +""" +import os +import json +import uuid +import asyncio +from datetime import datetime +from appPublic.log import debug, exception + +OUTPUT_DIR = '/data/ymq/wan22-outputs' +REPO_DIR = '/data/ymq/wan22-service/repo' +MODEL_PATH = '/data/ymq/models/Wan-AI/Wan2.2-TI2V-5B' + +# 全局 Wan22 实例,在 process_task 第一次调用时惰性初始化 +_engine = None + + +def _get_engine(longtasks): + """惰性加载 Wan22 引擎,模型常驻内存。""" + global _engine + if _engine is not None: + return _engine + + debug('Loading Wan22 engine (first call, may take 30-60s)...') + + # 把 repo 加入 sys.path,让 wan 包可导入 + import sys + if REPO_DIR not in sys.path: + sys.path.insert(0, REPO_DIR) + + from workers.wan22_wrapper import Wan22 + + gpu_id = getattr(longtasks, 'gpu_id', int(os.environ.get('WAN22_GPU_ID', '2'))) + os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) + + _engine = Wan22( + ckpt_dir=MODEL_PATH, + task='ti2v-5B', + device_id=0, # CUDA_VISIBLE_DEVICES 已隔离,从0开始 + use_prompt_extend=False, + offload_model=True, + ) + + debug(f'Wan22 engine loaded, GPU: {gpu_id}') + return _engine + + +async def run_generate(longtasks, payload): + """ + 执行视频生成(进程内推理)。 + + payload: { + task_id: str, + prompt: str, + image: str (optional), + size: str (default "1280*720"), + frame_num: int (default 81, 4n+1), + sample_steps: int (optional), + sample_guide_scale: float (optional), + base_seed: int (optional), + } + """ + task_id = payload.get('task_id', str(uuid.uuid4())[:12]) + prompt = payload.get('prompt', '') + image_path = payload.get('image', None) + size = payload.get('size', '1280*720') + frame_num = payload.get('frame_num', 81) + sample_steps = payload.get('sample_steps', None) + sample_guide_scale = payload.get('sample_guide_scale', None) + base_seed = payload.get('base_seed', None) + + # 校验 frame_num (4n+1) + frame_num = max(17, min(frame_num, 129)) + if (frame_num - 1) % 4 != 0: + frame_num = ((frame_num - 1) // 4) * 4 + 1 + + os.makedirs(OUTPUT_DIR, exist_ok=True) + output_file = os.path.join(OUTPUT_DIR, f'{task_id}.mp4') + + try: + # 惰性加载引擎(模型常驻,后续任务复用) + engine = _get_engine(longtasks) + + # 在 executor 中运行同步推理(不阻塞 asyncio 事件循环) + loop = asyncio.get_running_loop() + + def _infer(): + return engine.generate( + prompt=prompt, + image_path=image_path, + size=size, + frame_num=frame_num, + steps=sample_steps, + guide_scale=sample_guide_scale, + seed=base_seed if base_seed is not None else engine.seed, + save_file=output_file, + ) + + result = await loop.run_in_executor(None, _infer) + + if not os.path.exists(output_file): + return { + 'task_id': task_id, + 'status': 'failed', + 'error': 'Output file not created by engine', + } + + file_size = os.path.getsize(output_file) + debug(f'Video generated: {output_file} ({file_size} bytes)') + + return { + 'task_id': task_id, + 'status': 'completed', + 'video_url': f'/idfile?path={task_id}.mp4', + 'video_path': output_file, + 'size': size, + 'frame_num': frame_num, + 'file_size': file_size, + 'prompt': prompt[:100], + 'seed': result.get('seed'), + } + + except Exception as e: + exception(f'Generation error: {e}') + return { + 'task_id': task_id, + 'status': 'failed', + 'error': str(e), + } diff --git a/workers/generate_subprocess.py.bak b/workers/generate_subprocess.py.bak new file mode 100755 index 0000000..c7327a9 --- /dev/null +++ b/workers/generate_subprocess.py.bak @@ -0,0 +1,131 @@ +# -*- coding:utf-8 -*- +""" +Wan2.2-TI2V-5B 视频生成 worker +调用官方 generate.py 脚本 +""" +import os +import json +import uuid +import asyncio +import subprocess +from datetime import datetime +from appPublic.log import debug, exception + +OUTPUT_DIR = '/data/ymq/wan22-outputs' +REPO_DIR = '/data/ymq/wan22-service/repo' +MODEL_PATH = '/data/ymq/models/Wan-AI/Wan2.2-TI2V-5B' +PYTHON = '/share/vllm-0.8.5/bin/python' + + +async def run_generate(longtasks, payload): + """ + Execute video generation via generate.py subprocess. + + payload: { + task_id: str, + prompt: str, + image: str (optional path for i2v), + size: str (default "1280*720"), + frame_num: int (default 81, must be 4n+1), + sample_steps: int (optional), + sample_guide_scale: float (optional), + base_seed: int (optional), + } + """ + task_id = payload.get('task_id', str(uuid.uuid4())[:12]) + prompt = payload.get('prompt', '') + image = payload.get('image', None) + size = payload.get('size', '1280*720') + frame_num = payload.get('frame_num', 81) + sample_steps = payload.get('sample_steps', None) + sample_guide_scale = payload.get('sample_guide_scale', None) + base_seed = payload.get('base_seed', None) + + # Ensure frame_num is 4n+1 + frame_num = max(17, min(frame_num, 129)) + if (frame_num - 1) % 4 != 0: + frame_num = ((frame_num - 1) // 4) * 4 + 1 + + os.makedirs(OUTPUT_DIR, exist_ok=True) + output_file = os.path.join(OUTPUT_DIR, f'{task_id}.mp4') + + # Build command + cmd = [ + PYTHON, 'generate.py', + '--task', 'ti2v-5B', + '--ckpt_dir', MODEL_PATH, + '--size', size, + '--frame_num', str(frame_num), + '--prompt', prompt, + '--save_file', output_file, + '--offload_model', 'True', + ] + + if image: + cmd.extend(['--image', image]) + + if sample_steps: + cmd.extend(['--sample_steps', str(sample_steps)]) + + if sample_guide_scale: + cmd.extend(['--sample_guide_scale', str(sample_guide_scale)]) + + if base_seed is not None: + cmd.extend(['--base_seed', str(base_seed)]) + + # Set CUDA_VISIBLE_DEVICES for single GPU + gpu_id = longtasks.gpu_id if longtasks.gpu_id else 2 + env = os.environ.copy() + env['CUDA_VISIBLE_DEVICES'] = str(gpu_id) + + debug(f'Running: {" ".join(cmd)}') + + try: + # Run in subprocess + proc = await asyncio.create_subprocess_exec( + *cmd, + cwd=REPO_DIR, + env=env, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + + stdout, stderr = await proc.communicate() + + if proc.returncode != 0: + error_msg = stderr.decode('utf-8', errors='ignore')[-500:] + exception(f'generate.py failed: {error_msg}') + return { + 'task_id': task_id, + 'status': 'failed', + 'error': error_msg + } + + # Check output file + if not os.path.exists(output_file): + return { + 'task_id': task_id, + 'status': 'failed', + 'error': 'Output file not created' + } + + file_size = os.path.getsize(output_file) + + return { + 'task_id': task_id, + 'status': 'completed', + 'video_url': f'/idfile?path={task_id}.mp4', + 'video_path': output_file, + 'size': size, + 'frame_num': frame_num, + 'file_size': file_size, + 'prompt': prompt[:100] + } + + except Exception as e: + exception(f'Generation error: {e}') + return { + 'task_id': task_id, + 'status': 'failed', + 'error': str(e) + } diff --git a/workers/wan22_wrapper.py b/workers/wan22_wrapper.py new file mode 100644 index 0000000..903d839 --- /dev/null +++ b/workers/wan22_wrapper.py @@ -0,0 +1,238 @@ +""" +Wan22 - OpenAI-compatible Video Generation Runtime + +特征: +1. OpenAI风格返回 +2. 严格串行推理锁(GPU安全) +3. 支持 t2v / i2v / ti2v / s2v +4. 模型常驻内存,跨任务复用 +""" +import os +import time +import uuid +import threading +import random +from dataclasses import dataclass + +import torch +from PIL import Image + +import wan +from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS +from wan.utils.prompt_extend import QwenPromptExpander + + +# 全局执行锁(关键) +_GLOBAL_INFER_LOCK = threading.Lock() + + +@dataclass +class OpenAIVideoResponse: + id: str + object: str + created: int + prompt: str + video_path: str + seed: int + + +class Wan22: + """ + Wan22 - OpenAI-compatible Video Generation Runtime + """ + + def __init__( + self, + ckpt_dir: str, + task: str = "ti2v-5B", + device_id: int = 0, + use_prompt_extend: bool = False, + prompt_extend_model: str = None, + seed: int = -1, + offload_model: bool = True, + ): + assert ckpt_dir, "ckpt_dir required" + assert task in WAN_CONFIGS + + self.ckpt_dir = ckpt_dir + self.task = task + self.device_id = device_id + + self.cfg = WAN_CONFIGS[task] + + self.seed = seed if seed >= 0 else random.randint(0, 2**31 - 1) + self.offload_model = offload_model + + self.use_prompt_extend = use_prompt_extend + self.prompt_expander = ( + QwenPromptExpander( + model_name=prompt_extend_model, + task=task, + is_vl=True, + device=device_id, + ) + if use_prompt_extend + else None + ) + + self.pipeline = self._build_pipeline() + + def _build_pipeline(self): + if "t2v" in self.task: + return wan.WanT2V( + config=self.cfg, + checkpoint_dir=self.ckpt_dir, + device_id=self.device_id, + rank=0, + ) + + if "i2v" in self.task: + return wan.WanI2V( + config=self.cfg, + checkpoint_dir=self.ckpt_dir, + device_id=self.device_id, + rank=0, + ) + + if "ti2v" in self.task: + return wan.WanTI2V( + config=self.cfg, + checkpoint_dir=self.ckpt_dir, + device_id=self.device_id, + rank=0, + ) + + if "s2v" in self.task: + return wan.WanS2V( + config=self.cfg, + checkpoint_dir=self.ckpt_dir, + device_id=self.device_id, + rank=0, + ) + + raise ValueError(self.task) + + def _expand(self, prompt, image=None): + if not self.use_prompt_extend: + return prompt + + out = self.prompt_expander(prompt, image=image) + return out.prompt if out.status else prompt + + def _pack(self, prompt, video_path, seed): + return { + "id": f"wan_{uuid.uuid4().hex}", + "object": "video.generation", + "created": int(time.time()), + "prompt": prompt, + "video_path": video_path, + "seed": seed, + } + + def generate(self, **kwargs): + """ + OpenAI-style unified entry. + 全局串行锁保证GPU安全。 + """ + with _GLOBAL_INFER_LOCK: + + prompt = kwargs.get("prompt") + image_path = kwargs.get("image_path") + + prompt = self._expand( + prompt, + image=Image.open(image_path).convert("RGB") if image_path else None, + ) + + size = kwargs.get("size", "1280*720") + size_cfg = SIZE_CONFIGS[size] + + seed = self.seed + + if "t2v" in self.task: + video = self.pipeline.generate( + prompt, + size=size_cfg, + frame_num=kwargs.get("frame_num"), + shift=kwargs.get("shift"), + sample_solver=kwargs.get("solver", "unipc"), + sampling_steps=kwargs.get("steps"), + guide_scale=kwargs.get("guide_scale"), + seed=seed, + offload_model=self.offload_model, + ) + + elif "i2v" in self.task: + img = Image.open(image_path).convert("RGB") + video = self.pipeline.generate( + prompt, + img, + size=size_cfg, + max_area=MAX_AREA_CONFIGS[size], + frame_num=kwargs.get("frame_num"), + shift=kwargs.get("shift"), + sample_solver=kwargs.get("solver", "unipc"), + sampling_steps=kwargs.get("steps"), + guide_scale=kwargs.get("guide_scale"), + seed=seed, + offload_model=self.offload_model, + ) + + elif "ti2v" in self.task: + img = Image.open(image_path).convert("RGB") + video = self.pipeline.generate( + prompt, + img=img, + size=size_cfg, + max_area=MAX_AREA_CONFIGS[size], + frame_num=kwargs.get("frame_num"), + shift=kwargs.get("shift"), + sample_solver=kwargs.get("solver", "unipc"), + sampling_steps=kwargs.get("steps"), + guide_scale=kwargs.get("guide_scale"), + seed=seed, + offload_model=self.offload_model, + ) + + elif "s2v" in self.task: + video = self.pipeline.generate( + input_prompt=prompt, + ref_image_path=image_path, + audio_path=kwargs.get("audio_path"), + enable_tts=kwargs.get("enable_tts", False), + tts_prompt_audio=kwargs.get("tts_prompt_audio"), + tts_prompt_text=kwargs.get("tts_prompt_text"), + tts_text=kwargs.get("tts_text"), + num_repeat=kwargs.get("num_clip"), + pose_video=kwargs.get("pose_video"), + max_area=MAX_AREA_CONFIGS[size], + infer_frames=kwargs.get("infer_frames", 80), + shift=kwargs.get("shift"), + sample_solver=kwargs.get("solver", "unipc"), + sampling_steps=kwargs.get("steps"), + guide_scale=kwargs.get("guide_scale"), + seed=seed, + offload_model=self.offload_model, + init_first_frame=kwargs.get("start_from_ref", False), + ) + + else: + raise ValueError(self.task) + + # 保存视频 + video_path = kwargs.get( + "save_file", f"/tmp/{uuid.uuid4().hex}.mp4" + ) + + from wan.utils.utils import save_video + + save_video( + tensor=video[None], + save_file=video_path, + fps=self.cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1), + ) + + return self._pack(prompt, video_path, seed)