183 lines
6.1 KiB
Python

"""Pipeline step executor - runs steps asynchronously.
Each step handler is awaited with (pipeline_id, step_name, input_data, manifest) and returns output data.
The executor handles state transitions, artifact storage, and dependency resolution.
"""
import asyncio
import logging
from typing import Dict, Optional
from .state import (
STATE_PENDING, STATE_RUNNING, STATE_COMPLETED, STATE_FAILED,
PIPELINE_RUNNING, PIPELINE_COMPLETED, PIPELINE_FAILED, PIPELINE_PAUSED,
build_dependency_map,
)
from .storage import (
get_manifest, save_manifest, save_artifact,
update_step_state, create_new_version, reset_steps,
)
# In-flight executions keyed by pipeline id; each value is the asyncio.Task
# that drives that pipeline's run loop.
_active_tasks: Dict[str, asyncio.Task] = {}

# Shared logger for everything the executor reports.
logger = logging.getLogger("pipeline.executor")
async def start_pipeline(pipeline_id: str):
    """Start executing a pipeline in a background task.

    Schedules the executor's run loop for ``pipeline_id`` on the running
    event loop and records the task in ``_active_tasks`` so it can later be
    queried with ``is_running`` or cancelled with ``stop_pipeline``.

    If a task for this pipeline is already registered and has not finished,
    that task is returned instead of spawning a second one. Two run loops
    over the same manifest would race on step state, and overwriting the
    registry entry would make the first task unreachable for
    ``stop_pipeline``.

    Args:
        pipeline_id: Identifier of the pipeline to run.

    Returns:
        The ``asyncio.Task`` executing the pipeline.
    """
    existing = _active_tasks.get(pipeline_id)
    if existing is not None and not existing.done():
        logger.warning(f"Pipeline {pipeline_id} is already running")
        return existing

    task = asyncio.create_task(_run_pipeline(pipeline_id))
    _active_tasks[pipeline_id] = task
    return task
async def resume_pipeline(pipeline_id: str):
    """Resume a paused pipeline by scheduling its run loop again.

    Creates a background task running ``_run_pipeline`` and records it in
    ``_active_tasks``. If a task for this pipeline is already registered and
    still running, that task is returned rather than starting a second,
    competing run loop over the same manifest.

    NOTE(review): the run loop stops immediately while the manifest state is
    still ``PIPELINE_PAUSED``; presumably the caller clears the paused state
    before calling this -- confirm against the caller.

    Args:
        pipeline_id: Identifier of the pipeline to resume.

    Returns:
        The ``asyncio.Task`` executing the pipeline.
    """
    existing = _active_tasks.get(pipeline_id)
    if existing is not None and not existing.done():
        logger.warning(f"Pipeline {pipeline_id} is already running")
        return existing

    task = asyncio.create_task(_run_pipeline(pipeline_id))
    _active_tasks[pipeline_id] = task
    return task
async def stop_pipeline(pipeline_id: str):
    """Cancel the live task of a pipeline, if one is registered.

    Returns:
        True when an unfinished task was found, cancelled and removed from
        the registry; False when no task is registered for ``pipeline_id``
        or the registered task has already finished.
    """
    running = _active_tasks.get(pipeline_id)
    if not running or running.done():
        return False

    running.cancel()
    _active_tasks.pop(pipeline_id)
    return True
def is_running(pipeline_id: str) -> bool:
    """Report whether a registered task for the pipeline has not finished yet."""
    current = _active_tasks.get(pipeline_id)
    if current is None:
        return False
    return not current.done()
async def _run_pipeline(pipeline_id: str):
    """Drive a pipeline until it finishes, pauses, or is cancelled.

    Every iteration reloads the manifest, so state changes persisted between
    steps (such as a switch to ``PIPELINE_PAUSED``) are seen before the next
    step runs. The loop then executes the next ready step as chosen by
    ``_find_next_step``.

    When no runnable step is left, the pipeline is marked
    ``PIPELINE_FAILED`` if any step is in ``STATE_FAILED``, otherwise
    ``PIPELINE_COMPLETED``. A failed step leaves its dependents pending
    forever, so "nothing left to run" alone does not mean success.

    Cancellation and unexpected errors are logged and not propagated to
    whoever awaits the task.

    Args:
        pipeline_id: Identifier of the pipeline to execute.
    """
    try:
        while True:
            manifest = get_manifest(pipeline_id)
            if not manifest:
                logger.error(f"Pipeline {pipeline_id} not found")
                break
            if manifest["state"] == PIPELINE_PAUSED:
                logger.info(f"Pipeline {pipeline_id} paused, waiting for user action")
                break
            if manifest["state"] == PIPELINE_COMPLETED:
                logger.info(f"Pipeline {pipeline_id} already completed")
                break

            # Find next pending step with all deps completed
            next_step = _find_next_step(manifest)
            if not next_step:
                # Nothing is runnable any more: decide the final state.
                failed_steps = [
                    name
                    for name, info in manifest.get("steps", {}).items()
                    if info.get("state") == STATE_FAILED
                ]
                if failed_steps:
                    manifest["state"] = PIPELINE_FAILED
                    save_manifest(pipeline_id, manifest)
                    logger.error(
                        f"Pipeline {pipeline_id} failed "
                        f"(failed steps: {', '.join(map(str, failed_steps))})"
                    )
                else:
                    manifest["state"] = PIPELINE_COMPLETED
                    save_manifest(pipeline_id, manifest)
                    logger.info(f"Pipeline {pipeline_id} completed")
                break

            # Execute the step
            await _execute_step(pipeline_id, next_step)
    except asyncio.CancelledError:
        # Deliberately swallowed: a stopped pipeline ends its task quietly.
        logger.info(f"Pipeline {pipeline_id} cancelled")
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Pipeline {pipeline_id} error: {e}")
    finally:
        # Only deregister our own entry. After a stop followed by a quick
        # restart, the registry may already hold a newer task for this id,
        # and removing it would hide that task from is_running/stop_pipeline.
        if _active_tasks.get(pipeline_id) is asyncio.current_task():
            del _active_tasks[pipeline_id]
def _find_next_step(manifest: dict) -> Optional[str]:
    """Pick the first runnable step of a manifest.

    Steps are visited in ascending ``order`` (a missing ``order`` counts as
    0). A step is runnable when its state is ``STATE_PENDING`` and every
    name listed in its ``deps`` refers to a step whose state is
    ``STATE_COMPLETED``.

    Returns:
        The name of the first runnable step, or None when there is none.
    """
    steps = manifest.get("steps", {})

    def _position(item):
        return item[1].get("order", 0)

    def _is_ready(info):
        if info["state"] != STATE_PENDING:
            return False
        # A dependency that is unknown or not yet completed blocks the step.
        for dep in info.get("deps", []):
            if steps.get(dep, {}).get("state") != STATE_COMPLETED:
                return False
        return True

    ordered = sorted(steps.items(), key=_position)
    return next((name for name, info in ordered if _is_ready(info)), None)
async def _execute_step(pipeline_id: str, step_name: str):
    """Execute a single pipeline step and persist its state and artifacts.

    The step is marked ``STATE_RUNNING``, its inputs are gathered and saved
    as the "input" artifact, the registered handler (or ``_default_handler``
    when none is registered) is awaited, and the result is saved as the
    "output" artifact before the step is marked ``STATE_COMPLETED``.

    Any ``Exception`` is logged and recorded as ``STATE_FAILED`` with the
    error text; it is not re-raised. Cancellation puts the step back to
    ``STATE_PENDING`` and is re-raised, so a stopped pipeline does not leave
    a step stuck in ``STATE_RUNNING`` where a later run would never pick it
    up again.

    Args:
        pipeline_id: Identifier of the pipeline the step belongs to.
        step_name: Name of the step to run.

    Raises:
        asyncio.CancelledError: If the surrounding task is cancelled while
            the step is in progress.
    """
    update_step_state(pipeline_id, step_name, STATE_RUNNING)
    try:
        # Load input artifacts from dependencies
        manifest = get_manifest(pipeline_id)
        version = manifest["current_version"]
        input_data = await _gather_step_inputs(pipeline_id, step_name, version, manifest)

        # Save input artifact
        save_artifact(pipeline_id, version, step_name, "input", input_data)

        # Execute step handler
        handler = STEP_HANDLERS.get(step_name, _default_handler)
        output_data = await handler(pipeline_id, step_name, input_data, manifest)

        # Save output artifact
        save_artifact(pipeline_id, version, step_name, "output", output_data)
        update_step_state(pipeline_id, step_name, STATE_COMPLETED)
        logger.info(f"Step {step_name} completed for {pipeline_id}")
    except asyncio.CancelledError:
        # CancelledError is not an Exception subclass on Python 3.8+, so it
        # needs its own branch; undo the RUNNING marker before propagating.
        update_step_state(pipeline_id, step_name, STATE_PENDING)
        logger.info(f"Step {step_name} cancelled for {pipeline_id}")
        raise
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Step {step_name} failed for {pipeline_id}: {e}")
        update_step_state(pipeline_id, step_name, STATE_FAILED, str(e))
async def _gather_step_inputs(pipeline_id: str, step_name: str, version: int, manifest: dict) -> dict:
    """Assemble the input payload for a step.

    For every dependency listed under the step's ``deps`` the stored
    "output" artifact is read; when present, its ``data`` (or ``{}``) is
    placed under the dependency's name. An entry named ``_params`` is
    renamed to ``params``.

    If an "input" artifact with non-empty ``data`` is already stored for
    this step and version, that data is returned instead of the gathered
    dependency outputs.

    Returns:
        The stored input data when it exists, otherwise the mapping of
        dependency name to that dependency's output data.
    """
    from .storage import get_artifact

    dep_names = manifest.get("steps", {}).get(step_name, {}).get("deps", [])

    collected = {}
    for dep_name in dep_names:
        artifact = get_artifact(pipeline_id, version, dep_name, "output")
        if not artifact:
            continue
        collected[dep_name] = artifact.get("data", {})

    # Expose the initial parameters under the public "params" key.
    if "_params" in collected:
        collected["params"] = collected.pop("_params")

    # A stored input for this step takes precedence over gathered outputs.
    override = get_artifact(pipeline_id, version, step_name, "input")
    if override and override.get("data"):
        return override["data"]
    return collected
# === Step Handlers ===
async def _default_handler(pipeline_id: str, step_name: str, input_data: dict, manifest: dict) -> dict:
    """Fallback handler used when no handler is registered for a step.

    Logs the call, waits half a second to imitate work, and returns a small
    report containing the step name, a fixed "completed" status and the
    first 200 characters of the stringified input. ``manifest`` is unused.
    """
    logger.info(f"Default handler for {step_name}, pipeline {pipeline_id}")
    await asyncio.sleep(0.5)  # stand-in for real processing time

    preview = str(input_data)[:200]
    report = {"step": step_name, "status": "completed"}
    report["input_summary"] = preview
    return report
# Registry of step name -> handler. It starts out empty; the actual KTV
# pipeline step implementations are expected to add themselves through
# register_handler().
STEP_HANDLERS = {}


def register_handler(step_name: str, handler):
    """Bind ``handler`` to ``step_name``, replacing any earlier binding."""
    STEP_HANDLERS.update({step_name: handler})