Add standard API endpoints (status/submit/task) and upscale worker

This commit is contained in:
yumoqing 2026-06-14 16:25:38 +08:00
parent 769fc4968e
commit e43159f875
3 changed files with 145 additions and 38 deletions

View File

@ -1,13 +1,14 @@
# -*- coding:utf-8 -*-
# GET /api/status - 服务状态
# GET /api/status - Real-ESRGAN服务状态
import subprocess
import json
result = {
'service': 'wan22-video-generation',
'model': 'Wan2.2-TI2V-5B',
'gpu_id': 2,
'service': 'realesrgan-video-upscale',
'model': 'RealESRGAN_x2plus',
'gpu_id': 5,
'scale_factor': 2,
'gpus': []
}

View File

@ -1,5 +1,5 @@
# -*- coding:utf-8 -*-
# POST /api/submit - 提交视频生成任务
# POST /api/submit - 提交视频超分任务
import json
import uuid
@ -8,32 +8,20 @@ from ahserver.serverenv import ServerEnv
method = request.method
if method == 'POST':
prompt = params_kw.get('prompt', '')
if not prompt:
return json.dumps({'error': 'prompt is required'}, ensure_ascii=False)
video_path = params_kw.get('video_path', '')
if not video_path:
return json.dumps({'error': 'video_path is required'}, ensure_ascii=False)
task_id = params_kw.get('task_id', str(uuid.uuid4()).replace("-", "")[:12])
image = params_kw.get('image', None)
size = params_kw.get('size', '1280*720')
frame_num = params_kw.get('frame_num', 81)
sample_steps = params_kw.get('sample_steps', None)
sample_guide_scale = params_kw.get('sample_guide_scale', None)
base_seed = params_kw.get('base_seed', None)
valid_sizes = ['704*1280', '1280*704']
if size not in valid_sizes:
return json.dumps({'error': f'invalid size, must be one of: {valid_sizes}'}, ensure_ascii=False)
scale = params_kw.get('scale', 2)
output_format = params_kw.get('output_format', 'mp4')
payload = {
'task_type': 'generate_video',
'task_type': 'upscale_video',
'task_id': task_id,
'prompt': prompt,
'image': image,
'size': size,
'frame_num': int(frame_num),
'sample_steps': int(sample_steps) if sample_steps else None,
'sample_guide_scale': float(sample_guide_scale) if sample_guide_scale else None,
'base_seed': int(base_seed) if base_seed else None,
'video_path': video_path,
'scale': int(scale),
'output_format': output_format
}
env = ServerEnv()
@ -47,9 +35,8 @@ if method == 'POST':
return json.dumps({
'task_id': real_task_id,
'status': 'queued',
'prompt': prompt[:100],
'size': size,
'frame_num': payload['frame_num'],
'video_path': video_path,
'scale': int(scale),
'message': 'task submitted',
'check_url': f'/api/task?task_id={real_task_id}'
}, ensure_ascii=False)
@ -58,14 +45,9 @@ else:
return json.dumps({
'usage': 'POST with JSON body',
'params': {
'prompt': 'string (required)',
'image': 'string (optional, server path for I2V)',
'size': 'string (default 1280*720)',
'frame_num': 'int (default 81, 4n+1, range 17-129)',
'sample_steps': 'int (optional)',
'sample_guide_scale': 'float (optional)',
'base_seed': 'int (optional)',
'video_path': 'string (required, server path to video file)',
'scale': 'int (default 2, upscale factor)',
'output_format': 'string (default mp4)',
'task_id': 'string (optional, auto-generated)',
},
'valid_sizes': ['704*1280', '1280*704']
}
}, ensure_ascii=False)

124
workers/upscale.py Normal file
View File

@ -0,0 +1,124 @@
import os
import asyncio
import cv2
import shutil
import subprocess
from pathlib import Path
from appPublic.log import debug
def _load_model(tasks):
"""Lazy-load Real-ESRGAN model (stays in VRAM)"""
if tasks.upsampler is not None:
return
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
model_path = "/data/ymq/models/RealESRGAN_x2plus.pth"
if not os.path.exists(model_path):
os.makedirs(os.path.dirname(model_path), exist_ok=True)
subprocess.run([
"wget", "-q",
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
"-O", model_path
], check=True)
device = "cuda:0"
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
tasks.upsampler = RealESRGANer(
scale=2,
model_path=model_path,
model=model,
tile=0, tile_pad=10, pre_pad=0,
half=False, device=device
)
debug(f"Real-ESRGAN x2 model loaded on {device}")
async def run_upscale(tasks, payload: dict):
"""Upscale a video file using Real-ESRGAN"""
video_path = payload.get("video_path", "")
output_dir = payload.get("output_dir", "/data/ymq/upscaled-outputs")
scale = payload.get("scale", 2)
if not video_path or not os.path.exists(video_path):
return {"task_id": payload.get("task_id", ""), "status": "failed", "error": f"Video not found: {video_path}"}
# Load model on first call
_load_model(tasks)
task_id = payload.get("task_id", "unknown")
task_dir = Path(f"/tmp/realesrgan_{task_id}")
frames_dir = task_dir / "frames"
upscaled_dir = task_dir / "upscaled"
try:
for d in [task_dir, frames_dir, upscaled_dir]:
d.mkdir(parents=True, exist_ok=True)
# Extract frames
debug(f"[{task_id}] Extracting frames...")
subprocess.run([
"ffmpeg", "-i", video_path, "-q:v", "1",
str(frames_dir / "frame_%04d.jpg")
], check=True, capture_output=True)
# Get video fps
result = subprocess.run([
"ffprobe", "-v", "error", "-select_streams", "v:0",
"-show_entries", "stream=r_frame_rate",
"-of", "csv=p=0", video_path
], check=True, capture_output=True, text=True)
fps = result.stdout.strip()
# Upscale frames
frame_files = sorted(frames_dir.glob("*.jpg"))
total = len(frame_files)
debug(f"[{task_id}] Upscaling {total} frames...")
for i, frame_path in enumerate(frame_files, 1):
img = cv2.imread(str(frame_path), cv2.IMREAD_UNCHANGED)
output, _ = tasks.upsampler.enhance(img, outscale=scale)
cv2.imwrite(str(upscaled_dir / frame_path.name), output)
if i % 10 == 0 or i == total:
debug(f"[{task_id}] Frame {i}/{total}")
# Re-encode video
os.makedirs(output_dir, exist_ok=True)
stem = Path(video_path).stem
output_video = Path(output_dir) / f"{stem}_upscaled.mp4"
debug(f"[{task_id}] Re-encoding video...")
subprocess.run([
"ffmpeg", "-framerate", fps,
"-i", str(upscaled_dir / "frame_%04d.jpg"),
"-c:v", "libx264", "-preset", "slow", "-crf", "18",
"-pix_fmt", "yuv420p", str(output_video), "-y"
], check=True, capture_output=True)
# Get output resolution
result = subprocess.run([
"ffprobe", "-v", "error", "-select_streams", "v:0",
"-show_entries", "stream=width,height",
"-of", "csv=p=0", str(output_video)
], check=True, capture_output=True, text=True)
w, h = result.stdout.strip().split(",")
debug(f"[{task_id}] Done: {output_video} ({w}x{h})")
return {
"task_id": task_id,
"status": "success",
"video_path": str(output_video),
"resolution": f"{w}x{h}",
"frames": total
}
except Exception as e:
debug(f"[{task_id}] Error: {e}")
return {"task_id": task_id, "status": "failed", "error": str(e)}
finally:
if task_dir.exists():
shutil.rmtree(task_dir)