362 lines
12 KiB
Python

"""Pipeline Service - aiohttp HTTP API server.
Endpoints:
POST /v1/task/submit - Create and start a new pipeline
GET /v1/tasks - List pipelines for a user
GET /v1/task/{id} - Get pipeline detail with node tree
GET /v1/task/{id}/node/{step} - Get node input/output
POST /v1/task/update - Modify node artifact + cascade rerun
POST /v1/task/{id}/pause - Pause a running pipeline
POST /v1/task/{id}/resume - Resume a paused pipeline
GET /v1/health - Health check
"""
import json
import logging
import os
from aiohttp import web
from .state import (
PIPELINE_SUBMITTED, PIPELINE_RUNNING, PIPELINE_COMPLETED, PIPELINE_FAILED, PIPELINE_PAUSED,
get_cascade_rerun_steps, get_rerun_from_next, build_dependency_map,
)
from .storage import (
create_pipeline, get_manifest, save_manifest,
get_artifact, save_artifact, get_all_artifacts,
create_new_version, reset_steps, get_user_pipelines,
)
from .executor import start_pipeline, resume_pipeline, stop_pipeline, is_running
# Module-level logger for the HTTP API layer.
logger = logging.getLogger("pipeline.api")
# Headers attached to every response built by json_response(); the charset is
# spelled out because json_response() serializes with ensure_ascii=False.
JSON_HEADERS = {"Content-Type": "application/json; charset=utf-8"}
def json_response(data, status=200):
    """Serialize ``data`` as JSON and wrap it in an aiohttp response.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``) and
    the response carries the module-wide ``JSON_HEADERS``.

    Args:
        data: Any JSON-serializable object.
        status: HTTP status code of the response (default 200).
    """
    payload = json.dumps(data, ensure_ascii=False)
    return web.Response(text=payload, headers=JSON_HEADERS, status=status)
def error_response(message, status=400):
    """Return the standard error envelope as a JSON response.

    Args:
        message: Human-readable description of the failure.
        status: HTTP status code to send (default 400).
    """
    envelope = {"status": "error", "message": message}
    return json_response(envelope, status=status)
# === Health ===
async def handle_health(request):
    """GET /v1/health - return a static payload identifying the service."""
    health = {
        "status": "ok",
        "service": "pipeline-service",
        "version": "1.0.0",
    }
    return json_response(health)
# === Submit New Pipeline ===
async def handle_submit(request):
    """POST /v1/task/submit - create a pipeline and start it.

    Body: {mode, title, lyrics?, input_audio?, input_video?, user_id, ...}

    ``mode``, ``title`` and ``user_id`` are required. Every other key of the
    body is forwarded to ``create_pipeline`` as the ``params`` dict.

    Returns:
        A JSON envelope with the new ``pipeline_id`` on success; a 400 error
        envelope when the request is invalid; a 500 error envelope when
        ``create_pipeline`` raises.
    """
    try:
        body = await request.json()
    except Exception:
        return error_response("Invalid JSON body")
    # Valid JSON is not necessarily an object: a list/string/number body
    # would otherwise blow up on .get() below and surface as a 500.
    if not isinstance(body, dict):
        return error_response("JSON body must be an object")

    # Checked in this order so the first missing field is the one reported.
    for field in ("mode", "title", "user_id"):
        if not body.get(field):
            return error_response(f"Missing required field: {field}")
    mode = body["mode"]
    title = body["title"]
    user_id = body["user_id"]

    # Validate mode (kept as a list: it is rendered into the error message).
    valid_modes = ["audio_lyrics", "video_lyrics", "lyrics_only"]
    if mode not in valid_modes:
        return error_response(f"Invalid mode: {mode}. Must be one of: {valid_modes}")

    # Mode-specific requirements: some modes need a media input.
    required_input = {
        "audio_lyrics": "input_audio",
        "video_lyrics": "input_video",
    }.get(mode)
    if required_input and not body.get(required_input):
        return error_response(f"Mode {mode} requires {required_input}")

    # Extract params (everything except mode/title/user_id).
    params = {k: v for k, v in body.items() if k not in ("mode", "title", "user_id")}
    try:
        manifest = create_pipeline(user_id, mode, title, params)
    except Exception as e:
        # Keep a server-side trace; the client only sees the short message.
        logger.exception("Failed to create pipeline for user %s", user_id)
        return error_response(f"Failed to create pipeline: {e}", 500)

    pipeline_id = manifest["pipeline_id"]
    # Kick off execution. start_pipeline is awaited here; the original
    # comment says the run itself happens in the background (see executor).
    await start_pipeline(pipeline_id)
    return json_response({
        "status": "ok",
        "pipeline_id": pipeline_id,
        "mode": mode,
        "title": title,
        "message": "Pipeline created and started",
    })
# === List User Pipelines ===
async def handle_list(request):
    """GET /v1/tasks?user_id=xxx - list summary info for a user's pipelines.

    Pipelines whose manifest cannot be loaded are skipped. Responds with a
    400 error envelope when ``user_id`` is missing from the query string.
    """
    user_id = request.query.get("user_id")
    if not user_id:
        return error_response("Missing query param: user_id")

    # Manifest fields copied into each summary entry, in output order.
    summary_fields = (
        "pipeline_id",
        "mode",
        "title",
        "state",
        "current_version",
        "created_at",
        "updated_at",
    )
    tasks = []
    for pid in get_user_pipelines(user_id):
        manifest = get_manifest(pid)
        if not manifest:
            continue
        tasks.append({field: manifest[field] for field in summary_fields})

    return json_response({"status": "ok", "tasks": tasks, "total": len(tasks)})
# === Pipeline Detail ===
async def handle_detail(request):
    """GET /v1/task/{id} - return the manifest plus an artifact summary.

    The artifact summary covers the pipeline's current version only and
    exposes each artifact's step, type, version and save time. Responds
    with a 404 error envelope when the pipeline has no manifest.
    """
    pipeline_id = request.match_info["id"]
    manifest = get_manifest(pipeline_id)
    if not manifest:
        return error_response(f"Pipeline not found: {pipeline_id}", 404)

    version = manifest["current_version"]
    artifacts = get_all_artifacts(pipeline_id, version)

    # Reduce every artifact to its metadata; the payload data is left out.
    summary_keys = ("step", "type", "version", "saved_at")
    artifact_summary = {
        name: {field: artifact.get(field) for field in summary_keys}
        for name, artifact in artifacts.items()
    }

    # Assemble the response in the documented key order.
    detail = {"status": "ok"}
    for field in ("pipeline_id", "user_id", "mode", "title", "state"):
        detail[field] = manifest[field]
    detail["current_version"] = version
    for field in ("created_at", "updated_at", "steps", "versions"):
        detail[field] = manifest[field]
    detail["artifacts"] = artifact_summary
    detail["is_running"] = is_running(pipeline_id)
    return json_response(detail)
# === Node Input/Output ===
async def handle_node(request):
    """GET /v1/task/{id}/node/{step}?version=N - return a step's input/output.

    ``version`` is optional; when absent or empty the pipeline's current
    version is used.

    Returns:
        A JSON envelope with the step's state and the ``data`` of its input
        and output artifacts (``None`` where an artifact does not exist);
        a 404 error envelope for an unknown pipeline; a 400 error envelope
        for an unknown step or a non-integer ``version``.
    """
    pipeline_id = request.match_info["id"]
    step = request.match_info["step"]
    version = request.query.get("version")

    manifest = get_manifest(pipeline_id)
    if not manifest:
        return error_response(f"Pipeline not found: {pipeline_id}", 404)
    if step not in manifest["steps"]:
        return error_response(f"Step not found: {step}")

    if version:
        # The query string is client-controlled: reject garbage with a 400
        # instead of letting int() raise and surface as a 500.
        try:
            v = int(version)
        except ValueError:
            return error_response(f"Invalid version: {version}")
    else:
        v = manifest["current_version"]

    input_art = get_artifact(pipeline_id, v, step, "input")
    output_art = get_artifact(pipeline_id, v, step, "output")
    step_info = manifest["steps"][step]
    return json_response({
        "status": "ok",
        "pipeline_id": pipeline_id,
        "step": step,
        "display_name": step_info.get("display_name", step),
        "version": v,
        "state": step_info["state"],
        "input": input_art["data"] if input_art else None,
        "output": output_art["data"] if output_art else None,
        "input_version": input_art.get("version") if input_art else None,
        "output_version": output_art.get("version") if output_art else None,
    })
# === Modify Node + Cascade Rerun ===
async def handle_update(request):
    """POST /v1/task/update - modify step artifacts and rerun dependents.

    Body: {pipeline_id, updates: {step: {content: ...}}, rerun_from: "node"|"next"}

    rerun_from:
        - "node": input changed, rerun from this step
        - "next": output changed, rerun from next step(s)

    The request is fully validated before anything is written. On success a
    new version is created, the modified artifacts are saved into it, the
    affected steps are reset, the pipeline state is set to running and
    execution is resumed.

    Returns:
        A JSON envelope with the new version number and the list of steps
        being rerun; a 400 error envelope for an invalid request or a
        pipeline that is still running; a 404 error envelope for an
        unknown pipeline.
    """
    try:
        body = await request.json()
    except Exception:
        return error_response("Invalid JSON body")
    # Valid JSON need not be an object; guard before calling .get().
    if not isinstance(body, dict):
        return error_response("JSON body must be an object")

    pipeline_id = body.get("pipeline_id")
    updates = body.get("updates", {})
    rerun_from = body.get("rerun_from", "node")
    if not pipeline_id:
        return error_response("Missing pipeline_id")
    if not updates:
        return error_response("Missing or empty updates")
    if not isinstance(updates, dict):
        return error_response("updates must be an object keyed by step name")
    # An unrecognized value used to be treated silently as "next", which
    # overwrites step outputs; reject it instead.
    if rerun_from not in ("node", "next"):
        return error_response(f"Invalid rerun_from: {rerun_from}. Must be 'node' or 'next'")

    manifest = get_manifest(pipeline_id)
    if not manifest:
        return error_response(f"Pipeline not found: {pipeline_id}", 404)
    # Only refuse when the manifest says "running" AND the executor agrees;
    # a manifest left in the running state with no live run may be edited.
    if manifest["state"] == PIPELINE_RUNNING and is_running(pipeline_id):
        return error_response("Pipeline is currently running. Pause it first.")

    mode = manifest["mode"]
    if rerun_from == "node":
        # Input modified - save as new input, rerun from this step.
        find_affected = get_cascade_rerun_steps
        io_type = "input"
    else:
        # Output modified - save as new output, rerun from next steps.
        find_affected = get_rerun_from_next
        io_type = "output"

    # Validate every update and collect all steps that need rerunning
    # before any state is changed.
    all_rerun = set()
    for step_name, step_update in updates.items():
        if step_name not in manifest["steps"]:
            return error_response(f"Unknown step: {step_name}")
        if not isinstance(step_update, dict):
            return error_response(f"Invalid update for step {step_name}: expected an object")
        all_rerun.update(find_affected(mode, step_name))

    # Create new version. The description string (Chinese for
    # "Modified steps: ...") is runtime data handed to create_new_version,
    # so it is kept verbatim.
    change_desc = f"修改步骤: {', '.join(updates.keys())} (rerun_from={rerun_from})"
    new_version = create_new_version(pipeline_id, change_desc)

    # Save modified artifacts into the new version.
    for step_name, step_update in updates.items():
        content = step_update.get("content", {})
        save_artifact(pipeline_id, new_version, step_name, io_type, content)

    # Reset affected steps to pending, ordered by each step's "order" value;
    # the name is a tie-breaker so the result does not depend on set order.
    all_rerun_list = sorted(
        all_rerun,
        key=lambda s: (manifest["steps"].get(s, {}).get("order", 0), s),
    )
    reset_steps(pipeline_id, all_rerun_list)

    # Re-read the manifest before updating it: the storage calls above were
    # made after the copy held here was loaded.
    manifest = get_manifest(pipeline_id)
    manifest["state"] = PIPELINE_RUNNING
    save_manifest(pipeline_id, manifest)

    # Resume execution.
    await resume_pipeline(pipeline_id)
    return json_response({
        "status": "ok",
        "pipeline_id": pipeline_id,
        "new_version": new_version,
        "rerun_steps": all_rerun_list,
        "rerun_from": rerun_from,
        "message": f"Created v{new_version}, rerunning {len(all_rerun_list)} steps",
    })
# === Pause/Resume ===
async def handle_pause(request):
    """POST /v1/task/{id}/pause - stop the pipeline and mark it paused.

    Responds with a 404 error envelope when the pipeline has no manifest.
    """
    pid = request.match_info["id"]
    manifest = get_manifest(pid)
    if manifest:
        await stop_pipeline(pid)
        manifest["state"] = PIPELINE_PAUSED
        save_manifest(pid, manifest)
        result = {"status": "ok", "pipeline_id": pid, "state": PIPELINE_PAUSED}
        return json_response(result)
    return error_response(f"Pipeline not found: {pid}", 404)
async def handle_resume(request):
    """POST /v1/task/{id}/resume - mark the pipeline running and resume it.

    The manifest state is saved before ``resume_pipeline`` is awaited.
    Responds with a 404 error envelope when the pipeline has no manifest.
    """
    pid = request.match_info["id"]
    manifest = get_manifest(pid)
    if manifest:
        manifest["state"] = PIPELINE_RUNNING
        save_manifest(pid, manifest)
        await resume_pipeline(pid)
        result = {"status": "ok", "pipeline_id": pid, "state": PIPELINE_RUNNING}
        return json_response(result)
    return error_response(f"Pipeline not found: {pid}", 404)
# === CORS Middleware ===
def _apply_cors_headers(headers):
    """Add this service's permissive CORS headers to a headers mapping."""
    headers["Access-Control-Allow-Origin"] = "*"
    headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
    headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"


@web.middleware
async def cors_middleware(request, handler):
    """Attach CORS headers to every response, including error responses.

    Two cases need special care beyond the normal success path:

    * aiohttp reports router 404/405 results and handler-raised HTTP errors
      by raising ``web.HTTPException``; the headers are added to the
      exception before re-raising so browsers can read the error.
    * A browser preflight (``OPTIONS``) for a path that only has GET/POST
      routes ends in a 405. It is answered with an empty 204 instead so the
      actual cross-origin request is allowed to proceed.
    """
    try:
        response = await handler(request)
    except web.HTTPException as exc:
        if request.method == "OPTIONS" and exc.status == 405:
            # Preflight for an existing path with no OPTIONS route.
            response = web.Response(status=204)
        else:
            _apply_cors_headers(exc.headers)
            raise
    _apply_cors_headers(response.headers)
    return response
# === App Setup ===
def create_app():
    """Build the aiohttp application: CORS middleware plus all /v1 routes."""
    app = web.Application(middlewares=[cors_middleware])
    add_get = app.router.add_get
    add_post = app.router.add_post

    # (registrar, path, handler), registered in this order.
    route_table = (
        (add_get, "/v1/health", handle_health),
        (add_post, "/v1/task/submit", handle_submit),
        (add_get, "/v1/tasks", handle_list),
        (add_get, "/v1/task/{id}", handle_detail),
        (add_get, "/v1/task/{id}/node/{step}", handle_node),
        (add_post, "/v1/task/update", handle_update),
        (add_post, "/v1/task/{id}/pause", handle_pause),
        (add_post, "/v1/task/{id}/resume", handle_resume),
    )
    for register, path, route_handler in route_table:
        register(path, route_handler)
    return app
def main():
    """Configure logging, read host/port from the environment and serve.

    ``PIPELINE_PORT`` (default ``8190``) and ``PIPELINE_HOST`` (default
    ``0.0.0.0``) select the listening address.
    """
    logging.basicConfig(
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
        level=logging.INFO,
    )
    env = os.environ
    port = int(env.get("PIPELINE_PORT", "8190"))
    host = env.get("PIPELINE_HOST", "0.0.0.0")
    app = create_app()
    logger.info(f"Starting Pipeline Service on {host}:{port}")
    web.run_app(app, host=host, port=port)


# Script entry point: only serve when executed directly, not on import.
if __name__ == "__main__":
    main()