183 lines
6.1 KiB
Python
183 lines
6.1 KiB
Python
"""Pipeline step executor - runs steps asynchronously.
|
|
|
|
Each step handler receives (pipeline_id, step_name, input_data, manifest) and returns output data.
|
|
The executor handles state transitions, artifact storage, and dependency resolution.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Dict, Optional
|
|
|
|
from .state import (
|
|
STATE_PENDING, STATE_RUNNING, STATE_COMPLETED, STATE_FAILED,
|
|
PIPELINE_RUNNING, PIPELINE_COMPLETED, PIPELINE_FAILED, PIPELINE_PAUSED,
|
|
build_dependency_map,
|
|
)
|
|
from .storage import (
|
|
get_manifest, save_manifest, save_artifact,
|
|
update_step_state, create_new_version, reset_steps,
|
|
)
|
|
|
|
# Module logger; every executor message is emitted on the "pipeline.executor" channel.
logger = logging.getLogger("pipeline.executor")

# Background execution tasks tracked by this module, keyed by pipeline id.
# Entries are added by start_pipeline / resume_pipeline and removed either by
# stop_pipeline or when _run_pipeline finishes.
_active_tasks: Dict[str, asyncio.Task] = dict()
|
|
|
|
|
|
async def start_pipeline(pipeline_id: str):
    """Spawn the background task that executes a pipeline and return it.

    Execution picks up at the first pending step whose dependencies are
    completed; this function does not reset any step state.

    If a task for ``pipeline_id`` is already running, that task is returned
    instead of spawning a second one. Two loops over the same manifest would
    race: the second skips the step the first marked running and may finalize
    the pipeline while that step is still executing.
    """
    existing = _active_tasks.get(pipeline_id)
    if existing is not None and not existing.done():
        logger.info(f"Pipeline {pipeline_id} already running, reusing its task")
        return existing

    # No await between the check above and the registration below, so no other
    # coroutine on this event loop can slip in a duplicate task.
    task = asyncio.create_task(_run_pipeline(pipeline_id))
    _active_tasks[pipeline_id] = task
    return task
|
|
|
|
|
|
async def resume_pipeline(pipeline_id: str):
    """Resume a paused pipeline in a background task and return that task.

    _run_pipeline exits immediately while the manifest state is
    PIPELINE_PAUSED. A paused manifest is therefore switched to
    PIPELINE_RUNNING here first; without that, resuming would be a no-op. If
    the caller already changed the state, this step does nothing.

    If a task for ``pipeline_id`` is still running, it is returned rather than
    spawning a concurrent duplicate. Because the state was flipped first, that
    task keeps going instead of stopping on the pause.
    """
    manifest = get_manifest(pipeline_id)
    if manifest and manifest.get("state") == PIPELINE_PAUSED:
        manifest["state"] = PIPELINE_RUNNING
        save_manifest(pipeline_id, manifest)

    existing = _active_tasks.get(pipeline_id)
    if existing is not None and not existing.done():
        logger.info(f"Pipeline {pipeline_id} already running, reusing its task")
        return existing

    task = asyncio.create_task(_run_pipeline(pipeline_id))
    _active_tasks[pipeline_id] = task
    return task
|
|
|
|
|
|
async def stop_pipeline(pipeline_id: str):
    """Request cancellation of a pipeline's running task.

    Returns True when a live task was found. That task is sent a cancellation
    request and its entry is dropped from _active_tasks right away. The
    cancellation is delivered on a later event-loop iteration, so the task may
    still be unwinding when this returns.

    Returns False when no task is tracked for ``pipeline_id`` or it has
    already finished.
    """
    running = _active_tasks.get(pipeline_id)
    if running is None or running.done():
        return False

    running.cancel()
    del _active_tasks[pipeline_id]
    return True
|
|
|
|
|
|
def is_running(pipeline_id: str) -> bool:
    """Return True while a tracked, unfinished task exists for ``pipeline_id``."""
    tracked = _active_tasks.get(pipeline_id)
    if tracked is None:
        return False
    return not tracked.done()
|
|
|
|
|
|
async def _run_pipeline(pipeline_id: str):
    """Main pipeline execution loop, run as the pipeline's background task.

    Each iteration re-reads the manifest and stops if the pipeline is missing,
    paused, or already completed. Otherwise it awaits the next runnable step
    (see _find_next_step) and loops. Once no step is runnable, the terminal
    pipeline state is recorded by _finalize_pipeline.

    Cancellation and unexpected errors are logged rather than propagated.
    """
    try:
        while True:
            manifest = get_manifest(pipeline_id)
            if not manifest:
                logger.error(f"Pipeline {pipeline_id} not found")
                break

            if manifest["state"] == PIPELINE_PAUSED:
                logger.info(f"Pipeline {pipeline_id} paused, waiting for user action")
                break

            if manifest["state"] == PIPELINE_COMPLETED:
                logger.info(f"Pipeline {pipeline_id} already completed")
                break

            # Find next pending step with all deps completed
            next_step = _find_next_step(manifest)
            if not next_step:
                # Nothing runnable remains: record how the pipeline ended.
                _finalize_pipeline(pipeline_id, manifest)
                break

            await _execute_step(pipeline_id, next_step)

    except asyncio.CancelledError:
        logger.info(f"Pipeline {pipeline_id} cancelled")
    except Exception as e:
        # Keep the traceback; the bare message alone is hard to diagnose.
        logger.exception(f"Pipeline {pipeline_id} error: {e}")
    finally:
        # stop_pipeline drops the entry immediately, and a new task may be
        # registered under the same id before this one finishes unwinding.
        # Only remove the entry if it still refers to this task.
        if _active_tasks.get(pipeline_id) is asyncio.current_task():
            _active_tasks.pop(pipeline_id, None)


def _finalize_pipeline(pipeline_id: str, manifest: dict) -> None:
    """Save the terminal pipeline state once no step is runnable.

    PIPELINE_FAILED is recorded if any step failed, or if any step is still
    pending. A pending step here can never run, because its dependencies did
    not (and will not) complete. Otherwise PIPELINE_COMPLETED is recorded.
    """
    steps = manifest.get("steps", {})
    failed = [name for name, info in steps.items() if info.get("state") == STATE_FAILED]
    blocked = [name for name, info in steps.items() if info.get("state") == STATE_PENDING]

    if failed or blocked:
        manifest["state"] = PIPELINE_FAILED
        save_manifest(pipeline_id, manifest)
        logger.error(
            f"Pipeline {pipeline_id} failed (failed steps: {failed}, blocked steps: {blocked})"
        )
        return

    manifest["state"] = PIPELINE_COMPLETED
    save_manifest(pipeline_id, manifest)
    logger.info(f"Pipeline {pipeline_id} completed")
|
|
|
|
|
|
def _find_next_step(manifest: dict) -> Optional[str]:
    """Return the name of the first runnable step, or None if nothing can run now.

    Steps are scanned in ascending ``order``. A missing ``order`` counts as 0,
    and ties keep the manifest's iteration order. A step is runnable when its
    state is STATE_PENDING and every name in its ``deps`` refers to a step
    whose state is STATE_COMPLETED. A dependency name absent from the manifest
    never counts as completed.
    """
    all_steps = manifest.get("steps", {})

    def _dep_done(dep_name):
        return all_steps.get(dep_name, {}).get("state") == STATE_COMPLETED

    ordered = sorted(all_steps.items(), key=lambda item: item[1].get("order", 0))
    return next(
        (
            step_name
            for step_name, step_info in ordered
            if step_info["state"] == STATE_PENDING
            and all(_dep_done(dep) for dep in step_info.get("deps", []))
        ),
        None,
    )
|
|
|
|
|
|
async def _execute_step(pipeline_id: str, step_name: str):
    """Run one step and store its input and output artifacts.

    The step is first marked STATE_RUNNING. Its input payload (from
    _gather_step_inputs) is saved as the "input" artifact. The handler
    registered in STEP_HANDLERS (or _default_handler) is then awaited, its
    return value is saved as the "output" artifact, and the step is marked
    STATE_COMPLETED.

    Errors are not propagated: they are logged with a traceback and the step
    is marked STATE_FAILED with the error message.

    Cancellation (e.g. via stop_pipeline) resets the step to STATE_PENDING and
    is re-raised. Otherwise the step would stay "running" forever and never be
    picked up again on resume.
    """
    update_step_state(pipeline_id, step_name, STATE_RUNNING)

    try:
        # Load input artifacts from dependencies
        manifest = get_manifest(pipeline_id)
        if not manifest:
            raise LookupError(f"Pipeline {pipeline_id} not found")
        version = manifest["current_version"]
        input_data = await _gather_step_inputs(pipeline_id, step_name, version, manifest)

        save_artifact(pipeline_id, version, step_name, "input", input_data)

        handler = STEP_HANDLERS.get(step_name, _default_handler)
        output_data = await handler(pipeline_id, step_name, input_data, manifest)

        save_artifact(pipeline_id, version, step_name, "output", output_data)
        update_step_state(pipeline_id, step_name, STATE_COMPLETED)

        logger.info(f"Step {step_name} completed for {pipeline_id}")

    # Must precede `except Exception`: on Python 3.7 CancelledError is an
    # Exception subclass and would otherwise be swallowed as a step failure.
    except asyncio.CancelledError:
        logger.info(f"Step {step_name} cancelled for {pipeline_id}; resetting to pending")
        update_step_state(pipeline_id, step_name, STATE_PENDING)
        raise

    except Exception as e:
        logger.exception(f"Step {step_name} failed for {pipeline_id}: {e}")
        update_step_state(pipeline_id, step_name, STATE_FAILED, str(e))
|
|
|
|
|
|
async def _gather_step_inputs(pipeline_id: str, step_name: str, version: int, manifest: dict) -> dict:
    """Build the input payload for ``step_name`` at artifact ``version``.

    If a stored "input" artifact for this step has truthy ``data``, that data
    is returned as-is and takes precedence over dependency outputs.

    Otherwise the result maps each dependency name to the ``data`` of its
    stored "output" artifact (``{}`` when that artifact has no ``data`` key).
    Dependencies without a stored output artifact are omitted. A dependency
    named ``_params`` is exposed under the key ``params`` instead.

    NOTE(review): _execute_step saves the gathered input under this same
    (version, step_name, "input") key. On a re-run of the step at the same
    version, that saved copy would presumably be treated as a user override
    here. Confirm this is intended.
    """
    from .storage import get_artifact

    dep_names = manifest.get("steps", {}).get(step_name, {}).get("deps", [])

    collected = {}
    for dep_name in dep_names:
        artifact = get_artifact(pipeline_id, version, dep_name, "output")
        if not artifact:
            continue
        collected[dep_name] = artifact.get("data", {})

    # Presumably "_params" is a pseudo-step carrying the pipeline's initial
    # params; its data is re-keyed as "params" for the handler.
    if "_params" in collected:
        params_data = collected.pop("_params")
        collected["params"] = params_data

    # A user-modified input for this step, if present, wins outright.
    override = get_artifact(pipeline_id, version, step_name, "input")
    override_data = override.get("data") if override else None
    return override_data if override_data else collected
|
|
|
|
|
|
# === Step Handlers ===


async def _default_handler(pipeline_id: str, step_name: str, input_data: dict, manifest: dict) -> dict:
    """Fallback handler for steps with no registered implementation (testing stub).

    Sleeps 0.5 s to stand in for real work. It then reports the step as
    completed, together with a string summary of its input truncated to 200
    characters. ``manifest`` is accepted for signature compatibility but is
    not used.
    """
    logger.info(f"Default handler for {step_name}, pipeline {pipeline_id}")
    await asyncio.sleep(0.5)
    summary = str(input_data)[:200]
    return {
        "step": step_name,
        "status": "completed",
        "input_summary": summary,
    }
|
|
|
|
|
|
# Registry of step handlers, keyed by step name. _execute_step looks each step
# up here and falls back to _default_handler when it has no entry. Populate it
# via register_handler(); the actual KTV pipeline step implementations go here.
STEP_HANDLERS = dict()
|
|
|
|
|
|
def register_handler(step_name: str, handler):
    """Install ``handler`` as the implementation of ``step_name``.

    Any handler previously registered under the same name is replaced.
    _execute_step awaits the result of
    ``handler(pipeline_id, step_name, input_data, manifest)``, so the handler
    must return an awaitable (typically it is an ``async def`` function).
    """
    STEP_HANDLERS.update({step_name: handler})
|