362 lines
12 KiB
Python
362 lines
12 KiB
Python
"""Pipeline Service - aiohttp HTTP API server.
|
|
|
|
Endpoints:
|
|
POST /v1/task/submit - Create and start a new pipeline
|
|
GET /v1/tasks - List pipelines for a user
|
|
GET /v1/task/{id} - Get pipeline detail with node tree
|
|
GET /v1/task/{id}/node/{step} - Get node input/output
|
|
POST /v1/task/update - Modify node artifact + cascade rerun
|
|
POST /v1/task/{id}/pause - Pause a running pipeline
|
|
POST /v1/task/{id}/resume - Resume a paused pipeline
|
|
GET /v1/health - Health check
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
from aiohttp import web
|
|
|
|
from .state import (
|
|
PIPELINE_SUBMITTED, PIPELINE_RUNNING, PIPELINE_COMPLETED, PIPELINE_FAILED, PIPELINE_PAUSED,
|
|
get_cascade_rerun_steps, get_rerun_from_next, build_dependency_map,
|
|
)
|
|
from .storage import (
|
|
create_pipeline, get_manifest, save_manifest,
|
|
get_artifact, save_artifact, get_all_artifacts,
|
|
create_new_version, reset_steps, get_user_pipelines,
|
|
)
|
|
from .executor import start_pipeline, resume_pipeline, stop_pipeline, is_running
|
|
|
|
# Named logger for this HTTP API module.
logger = logging.getLogger(name="pipeline.api")

# Header set attached to every JSON response built by json_response().
JSON_HEADERS = dict([("Content-Type", "application/json; charset=utf-8")])
|
|
|
|
|
|
def json_response(data, status=200):
    """Serialize *data* to JSON and wrap it in an aiohttp ``web.Response``.

    Non-ASCII characters are emitted verbatim (``ensure_ascii=False``) and the
    Content-Type header is taken from ``JSON_HEADERS``.

    Args:
        data: JSON-serializable payload.
        status: HTTP status code, 200 by default.
    """
    payload = json.dumps(data, ensure_ascii=False)
    return web.Response(headers=JSON_HEADERS, status=status, text=payload)
|
|
|
|
|
|
def error_response(message, status=400):
    """Build a JSON error body ``{"status": "error", "message": message}``.

    Args:
        message: Human-readable description returned to the client.
        status: HTTP status code, 400 by default.
    """
    payload = {"status": "error", "message": message}
    return json_response(payload, status=status)
|
|
|
|
|
|
# === Health ===
|
|
|
|
async def handle_health(request):
    """GET /v1/health - static liveness payload; the request is not inspected."""
    payload = dict(status="ok", service="pipeline-service", version="1.0.0")
    return json_response(payload)
|
|
|
|
|
|
# === Submit New Pipeline ===
|
|
|
|
# Modes accepted by POST /v1/task/submit. Kept as a list so the
# "Must be one of: [...]" error text is unchanged for clients.
_SUBMIT_MODES = ["audio_lyrics", "video_lyrics", "lyrics_only"]

# Mode -> body field that must be present (and truthy) for that mode.
_SUBMIT_MODE_INPUTS = {
    "audio_lyrics": "input_audio",
    "video_lyrics": "input_video",
}


def _validate_submit_body(body):
    """Return an error message for an invalid submit body, or None if it is valid."""
    # Same order as the public contract: mode, then title, then user_id.
    for field in ("mode", "title", "user_id"):
        if not body.get(field):
            return f"Missing required field: {field}"

    mode = body["mode"]
    if mode not in _SUBMIT_MODES:
        return f"Invalid mode: {mode}. Must be one of: {_SUBMIT_MODES}"

    required_input = _SUBMIT_MODE_INPUTS.get(mode)
    if required_input and not body.get(required_input):
        return f"Mode {mode} requires {required_input}"
    return None


async def handle_submit(request):
    """POST /v1/task/submit - create a pipeline and start it in the background.

    Body: {mode, title, lyrics?, input_audio?, input_video?, user_id, ...}

    ``mode``, ``title`` and ``user_id`` are required. ``mode`` must be one of
    ``_SUBMIT_MODES``; "audio_lyrics" additionally needs ``input_audio`` and
    "video_lyrics" needs ``input_video``. Every other body field is forwarded
    to ``create_pipeline`` as ``params``.

    Responses:
        200 with the new ``pipeline_id`` on success.
        400 for a malformed body or failed validation.
        500 if the pipeline could not be created or started. In the second
        case the ``pipeline_id`` is included so the client can resume it.
    """
    try:
        body = await request.json()
    except Exception:
        return error_response("Invalid JSON body")
    # request.json() also accepts lists/strings/numbers; everything below needs a dict.
    if not isinstance(body, dict):
        return error_response("Invalid JSON body: expected an object")

    problem = _validate_submit_body(body)
    if problem:
        return error_response(problem)

    mode = body["mode"]
    title = body["title"]
    user_id = body["user_id"]

    # Extract params (everything except mode/title/user_id).
    params = {k: v for k, v in body.items() if k not in ("mode", "title", "user_id")}

    try:
        manifest = create_pipeline(user_id, mode, title, params)
    except Exception as e:
        logger.exception("Failed to create pipeline for user %s", user_id)
        return error_response(f"Failed to create pipeline: {e}", 500)

    pipeline_id = manifest["pipeline_id"]

    # Start execution in the background. The pipeline already exists at this
    # point, so report its id if starting fails.
    try:
        await start_pipeline(pipeline_id)
    except Exception as e:
        logger.exception("Failed to start pipeline %s", pipeline_id)
        return json_response({
            "status": "error",
            "pipeline_id": pipeline_id,
            "message": f"Pipeline created but failed to start: {e}",
        }, status=500)

    return json_response({
        "status": "ok",
        "pipeline_id": pipeline_id,
        "mode": mode,
        "title": title,
        "message": "Pipeline created and started",
    })
|
|
|
|
|
|
# === List User Pipelines ===
|
|
|
|
async def handle_list(request):
    """GET /v1/tasks?user_id=xxx - summarize every pipeline owned by a user.

    Pipeline ids whose manifest lookup returns a falsy value are skipped.
    Responds 400 when the ``user_id`` query parameter is missing or empty.
    """
    user_id = request.query.get("user_id")
    if not user_id:
        return error_response("Missing query param: user_id")

    summary_keys = (
        "pipeline_id", "mode", "title", "state",
        "current_version", "created_at", "updated_at",
    )
    manifests = (get_manifest(pid) for pid in get_user_pipelines(user_id))
    tasks = [{key: m[key] for key in summary_keys} for m in manifests if m]

    return json_response({"status": "ok", "tasks": tasks, "total": len(tasks)})
|
|
|
|
|
|
# === Pipeline Detail ===
|
|
|
|
async def handle_detail(request):
    """GET /v1/task/{id} - manifest fields, artifact metadata and run status.

    Artifact metadata (step/type/version/saved_at) comes from the manifest's
    ``current_version`` only; artifact payloads are not returned. Responds 404
    when no manifest exists for the id.
    """
    pipeline_id = request.match_info["id"]
    manifest = get_manifest(pipeline_id)
    if not manifest:
        return error_response(f"Pipeline not found: {pipeline_id}", 404)

    current = manifest["current_version"]
    summary_fields = ("step", "type", "version", "saved_at")
    artifact_summary = {
        name: {field: art.get(field) for field in summary_fields}
        for name, art in get_all_artifacts(pipeline_id, current).items()
    }

    head_keys = ("pipeline_id", "user_id", "mode", "title", "state")
    tail_keys = ("created_at", "updated_at", "steps", "versions")
    return json_response({
        "status": "ok",
        **{key: manifest[key] for key in head_keys},
        "current_version": current,
        **{key: manifest[key] for key in tail_keys},
        "artifacts": artifact_summary,
        "is_running": is_running(pipeline_id),
    })
|
|
|
|
|
|
# === Node Input/Output ===
|
|
|
|
async def handle_node(request):
    """GET /v1/task/{id}/node/{step}?version=N - a step's stored input/output.

    ``version`` defaults to the manifest's ``current_version``. Missing
    artifacts are reported as ``null``.

    Responds 404 for an unknown pipeline. Responds 400 for an unknown step or a
    ``version`` query value that is not an integer.
    """
    pipeline_id = request.match_info["id"]
    step = request.match_info["step"]
    version = request.query.get("version")

    manifest = get_manifest(pipeline_id)
    if not manifest:
        return error_response(f"Pipeline not found: {pipeline_id}", 404)

    if step not in manifest["steps"]:
        return error_response(f"Step not found: {step}")

    if version:
        # Query values are client-controlled text; reject garbage with a 400
        # instead of letting int() raise into a 500.
        try:
            v = int(version)
        except ValueError:
            return error_response(f"Invalid version: {version}")
    else:
        v = manifest["current_version"]

    input_art = get_artifact(pipeline_id, v, step, "input")
    output_art = get_artifact(pipeline_id, v, step, "output")
    step_info = manifest["steps"][step]

    return json_response({
        "status": "ok",
        "pipeline_id": pipeline_id,
        "step": step,
        "display_name": step_info.get("display_name", step),
        "version": v,
        "state": step_info["state"],
        "input": input_art["data"] if input_art else None,
        "output": output_art["data"] if output_art else None,
        "input_version": input_art.get("version") if input_art else None,
        "output_version": output_art.get("version") if output_art else None,
    })
|
|
|
|
|
|
# === Modify Node + Cascade Rerun ===
|
|
|
|
# Accepted values for the "rerun_from" field of POST /v1/task/update.
_RERUN_FROM_VALUES = ("node", "next")


async def handle_update(request):
    """POST /v1/task/update - edit step artifacts and rerun the affected steps.

    Body: {pipeline_id, updates: {step: {content: ...}}, rerun_from: "node"|"next"}

    rerun_from (defaults to "node"):
      - "node": the step's input was edited. ``content`` is saved as that
        step's input and the steps from ``get_cascade_rerun_steps`` are rerun.
      - "next": the step's output was edited. ``content`` is saved as that
        step's output and the steps from ``get_rerun_from_next`` are rerun.

    All of ``updates`` is validated before anything is written. On success:
    a new version is created, the edited artifacts are saved into it, the
    affected steps are reset, the manifest is marked running and the executor
    is resumed.

    Responds 404 when the pipeline does not exist. Responds 400 in these cases:
    malformed input, an unknown step or ``rerun_from`` value, or a manifest in
    the running state whose executor also reports it as running.
    """
    try:
        body = await request.json()
    except Exception:
        return error_response("Invalid JSON body")
    if not isinstance(body, dict):
        return error_response("Invalid JSON body: expected an object")

    pipeline_id = body.get("pipeline_id")
    updates = body.get("updates", {})
    rerun_from = body.get("rerun_from", "node")

    if not pipeline_id:
        return error_response("Missing pipeline_id")
    if not updates:
        return error_response("Missing or empty updates")
    if not isinstance(updates, dict):
        return error_response("updates must be an object mapping step names to {content: ...}")
    # Previously any unrecognized value silently behaved like "next".
    if rerun_from not in _RERUN_FROM_VALUES:
        return error_response(
            f"Invalid rerun_from: {rerun_from}. Must be one of: {list(_RERUN_FROM_VALUES)}"
        )

    manifest = get_manifest(pipeline_id)
    if not manifest:
        return error_response(f"Pipeline not found: {pipeline_id}", 404)

    if manifest["state"] == PIPELINE_RUNNING and is_running(pipeline_id):
        return error_response("Pipeline is currently running. Pause it first.")

    mode = manifest["mode"]
    steps = manifest["steps"]
    edits_output = rerun_from == "next"

    # Validate every update and collect the steps that must rerun before any
    # version/artifact is written.
    all_rerun = set()
    for step_name, step_update in updates.items():
        if step_name not in steps:
            return error_response(f"Unknown step: {step_name}")
        if not isinstance(step_update, dict):
            return error_response(f"Update for step {step_name} must be an object")
        if edits_output:
            # Output modified: rerun from the next step(s).
            all_rerun.update(get_rerun_from_next(mode, step_name))
        else:
            # Input modified: rerun from this step.
            all_rerun.update(get_cascade_rerun_steps(mode, step_name))

    # Version description text is persisted as-is ("修改步骤" = "modified steps").
    change_desc = f"修改步骤: {', '.join(updates.keys())} (rerun_from={rerun_from})"
    new_version = create_new_version(pipeline_id, change_desc)

    io_type = "output" if edits_output else "input"
    for step_name, step_update in updates.items():
        save_artifact(pipeline_id, new_version, step_name, io_type, step_update.get("content", {}))

    # Reset affected steps to pending, in the manifest's declared step order.
    all_rerun_list = sorted(all_rerun, key=lambda s: steps.get(s, {}).get("order", 0))
    reset_steps(pipeline_id, all_rerun_list)

    # Re-read: reset_steps/create_new_version may have rewritten the manifest.
    manifest = get_manifest(pipeline_id)
    manifest["state"] = PIPELINE_RUNNING
    save_manifest(pipeline_id, manifest)

    await resume_pipeline(pipeline_id)

    return json_response({
        "status": "ok",
        "pipeline_id": pipeline_id,
        "new_version": new_version,
        "rerun_steps": all_rerun_list,
        "rerun_from": rerun_from,
        "message": f"Created v{new_version}, rerunning {len(all_rerun_list)} steps",
    })
|
|
|
|
|
|
# === Pause/Resume ===
|
|
|
|
async def handle_pause(request):
    """POST /v1/task/{id}/pause - stop the executor and mark the pipeline paused.

    Responds 404 for an unknown pipeline. Responds 400 when the pipeline has
    already completed or failed, since there is nothing left to pause.
    """
    pipeline_id = request.match_info["id"]
    manifest = get_manifest(pipeline_id)
    if not manifest:
        return error_response(f"Pipeline not found: {pipeline_id}", 404)

    state = manifest["state"]
    if state in (PIPELINE_COMPLETED, PIPELINE_FAILED):
        return error_response(f"Pipeline already finished (state={state}), cannot pause")

    await stop_pipeline(pipeline_id)

    # Other coroutines (e.g. the executor) may have saved the manifest while we
    # awaited stop_pipeline. Re-read it so we only flip the state instead of
    # writing back the stale copy loaded above.
    manifest = get_manifest(pipeline_id) or manifest
    manifest["state"] = PIPELINE_PAUSED
    save_manifest(pipeline_id, manifest)

    return json_response({"status": "ok", "pipeline_id": pipeline_id, "state": PIPELINE_PAUSED})
|
|
|
|
|
|
async def handle_resume(request):
    """POST /v1/task/{id}/resume - mark the pipeline running and relaunch it.

    The running state is persisted before the executor is asked to resume.
    The current state is not checked. Responds 404 when no manifest exists.
    """
    pipeline_id = request.match_info["id"]
    manifest = get_manifest(pipeline_id)
    if manifest:
        manifest["state"] = PIPELINE_RUNNING
        save_manifest(pipeline_id, manifest)
        await resume_pipeline(pipeline_id)
        return json_response({"status": "ok", "pipeline_id": pipeline_id, "state": PIPELINE_RUNNING})
    return error_response(f"Pipeline not found: {pipeline_id}", 404)
|
|
|
|
|
|
# === CORS Middleware ===
|
|
|
|
# Permissive CORS headers attached to every response.
_CORS_HEADERS = {
    "Access-Control-Allow-Origin": "*",
    "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
    "Access-Control-Allow-Headers": "Content-Type, Authorization",
}


@web.middleware
async def cors_middleware(request, handler):
    """Add permissive CORS headers to every response.

    No OPTIONS routes are registered. Without special handling, a browser
    preflight (sent e.g. before a ``Content-Type: application/json`` POST)
    would hit aiohttp's 405 exception path and never see CORS headers, so the
    real request would be blocked. Preflights are therefore answered here with
    an empty 204. Raised ``web.HTTPException`` responses (404, 405, ...) get
    the headers too before propagating.
    """
    if request.method == "OPTIONS":
        response = web.Response(status=204)
    else:
        try:
            response = await handler(request)
        except web.HTTPException as exc:
            exc.headers.update(_CORS_HEADERS)
            raise
    response.headers.update(_CORS_HEADERS)
    return response
|
|
|
|
|
|
# === App Setup ===
|
|
|
|
def create_app():
    """Build the aiohttp application with CORS middleware and all /v1 routes."""
    application = web.Application(middlewares=[cors_middleware])
    router = application.router

    # (registrar, path, handler) in the original registration order.
    route_table = (
        (router.add_get, "/v1/health", handle_health),
        (router.add_post, "/v1/task/submit", handle_submit),
        (router.add_get, "/v1/tasks", handle_list),
        (router.add_get, "/v1/task/{id}", handle_detail),
        (router.add_get, "/v1/task/{id}/node/{step}", handle_node),
        (router.add_post, "/v1/task/update", handle_update),
        (router.add_post, "/v1/task/{id}/pause", handle_pause),
        (router.add_post, "/v1/task/{id}/resume", handle_resume),
    )
    for register, path, route_handler in route_table:
        register(path, route_handler)

    return application
|
|
|
|
|
|
def main():
    """Configure logging and serve the app until interrupted.

    Bind address comes from PIPELINE_HOST (default "0.0.0.0") and
    PIPELINE_PORT (default "8190", parsed with int()).
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
    )
    env = os.environ
    port = int(env.get("PIPELINE_PORT", "8190"))
    host = env.get("PIPELINE_HOST", "0.0.0.0")

    application = create_app()
    logger.info(f"Starting Pipeline Service on {host}:{port}")
    web.run_app(application, host=host, port=port)


if __name__ == "__main__":
    main()
|