# Listing metadata: 237 lines, 8.9 KiB, Python

"""Pipeline storage layer - MySQL via sqlor."""
import json
from typing import Dict, List, Optional
from sqlor.dbpools import DBPools
from appPublic.uniqueID import getID
from appPublic.log import debug
# Database name returned by _get_db() and passed to sqlorContext().
DBNAME = "pipeline"
def _get_db():
    """Return the ``(pools, dbname)`` pair used to open a sqlor context.

    A ``DBPools`` object is constructed on every call and paired with the
    module-level ``DBNAME``.
    """
    pools = DBPools()
    return pools, DBNAME
async def get_pipeline_steps(pipeline_id: str) -> list:
    """Read the step definitions of a pipeline from ``pipeline_steps``.

    Rows are fetched for ``pipeline_id`` sorted by ``step_order``. Each row is
    converted to a plain dict, and the ``deps`` entry of its ``step_config``
    JSON is copied to a top-level ``deps`` key so that build_step_graph() can
    find it.

    Args:
        pipeline_id: Value matched against the ``pipeline_id`` column.

    Returns:
        A list of step dicts (empty when no rows match). ``deps`` falls back
        to ``[]`` when ``step_config`` is missing, NULL, not valid JSON, or
        does not decode to a mapping.
    """
    db, dbname = _get_db()
    async with db.sqlorContext(dbname) as sor:
        recs = await sor.R('pipeline_steps', {'pipeline_id': pipeline_id}, sort='step_order')
        if not recs:
            return []
        result = []
        for rec in recs:
            # Test for a mapping first: every dict-subclass instance also has
            # ``__dict__``, so checking that attribute first would send mapping
            # records through dir() and collect method names, not columns.
            if isinstance(rec, dict):
                d = dict(rec)
            elif hasattr(rec, '__dict__'):
                d = {k: getattr(rec, k) for k in dir(rec) if not k.startswith('_')}
            else:
                d = dict(rec)
            # Extract deps from the step_config JSON.
            cfg_raw = d.get('step_config')
            try:
                cfg = json.loads(cfg_raw) if isinstance(cfg_raw, str) else cfg_raw
            except (json.JSONDecodeError, TypeError):
                cfg = {}
            # A NULL column or a JSON array/scalar has no ``.get``; treat it
            # as "no dependencies" instead of raising AttributeError.
            getter = getattr(cfg, 'get', None)
            d['deps'] = getter('deps', []) if callable(getter) else []
            result.append(d)
        return result
async def create_task(tenant_id: str, pipeline_id: str, owner_id: str, title: str, params: dict) -> str:
    """Insert a new row into ``pipeline_tasks`` and return its generated id.

    The row starts in state ``"submitted"`` at ``current_version`` 1, with
    ``params`` stored as a JSON string (non-ASCII kept as-is, values that are
    not JSON-serializable converted with ``str``).

    Args:
        tenant_id: Stored in the ``tenant_id`` column.
        pipeline_id: Stored in the ``pipeline_id`` column.
        owner_id: Stored in the ``owner_id`` column.
        title: Stored in the ``title`` column.
        params: Task parameters, serialized to JSON before storage.

    Returns:
        The id produced by ``getID()`` for the new task.
    """
    pools, name = _get_db()
    async with pools.sqlorContext(name) as sor:
        new_id = getID()
        row = dict(
            id=new_id,
            tenant_id=tenant_id,
            pipeline_id=pipeline_id,
            owner_id=owner_id,
            title=title,
            state="submitted",
            current_version=1,
            params=json.dumps(params, ensure_ascii=False, default=str),
        )
        await sor.C('pipeline_tasks', row)
        return new_id
async def init_task_steps(task_id: str, step_records: list):
    """Create one ``pipeline_task_steps`` row per pipeline step definition.

    Every row gets a fresh id from ``getID()`` and starts in state
    ``"pending"``. ``step_type`` and ``display_name`` fall back to the step's
    ``step_name`` when absent; ``step_order`` falls back to 0.

    Args:
        task_id: Stored in the ``task_id`` column of every created row.
        step_records: Step definition dicts. Each must contain ``step_name``;
            ``deps`` may be a JSON string (stored unchanged), any
            JSON-serializable value (serialized), or missing/``None`` (stored
            as ``"[]"``). ``None`` or an empty list creates nothing.

    Raises:
        KeyError: If a step record has no ``step_name``.
    """
    db, dbname = _get_db()
    async with db.sqlorContext(dbname) as sor:
        for rec in (step_records or []):
            step_name = rec['step_name']
            deps = rec.get('deps')
            if deps is None:
                # An explicit None would otherwise be stored as the JSON
                # literal "null" rather than an empty dependency list.
                deps_json = '[]'
            elif isinstance(deps, str):
                # Already serialized by the caller; store unchanged.
                deps_json = deps
            else:
                deps_json = json.dumps(deps)
            data = {
                "id": getID(),
                "task_id": task_id,
                "step_name": step_name,
                "step_type": rec.get('step_type', step_name),
                "display_name": rec.get('display_name', step_name),
                "step_order": rec.get('step_order', 0),
                "deps": deps_json,
                "state": "pending",
            }
            await sor.C('pipeline_task_steps', data)
async def get_task(tenant_id: str, task_id: str) -> Optional[dict]:
    """Fetch one task as a plain dict, restricted to the given tenant.

    Args:
        tenant_id: Value the row's ``tenant_id`` must match.
        task_id: Value the row's ``id`` must match.

    Returns:
        The first matching ``pipeline_tasks`` row converted to a dict, or
        ``None`` when nothing matches.
    """
    db, dbname = _get_db()
    async with db.sqlorContext(dbname) as sor:
        recs = await sor.R('pipeline_tasks', {'id': task_id, 'tenant_id': tenant_id})
        if not recs:
            return None
        rec = recs[0]
        # Test for a mapping first: every dict-subclass instance also has
        # ``__dict__``, so checking that attribute first would send mapping
        # records through dir() and collect method names, not columns.
        if isinstance(rec, dict):
            return dict(rec)
        if hasattr(rec, '__dict__'):
            return {k: getattr(rec, k) for k in dir(rec) if not k.startswith('_')}
        return dict(rec)
async def get_task_steps(task_id: str) -> list:
    """Return every ``pipeline_task_steps`` row of a task as plain dicts.

    Args:
        task_id: Value matched against the ``task_id`` column.

    Returns:
        A list of row dicts, requested sorted by ``step_order``; empty when
        the task has no step rows.
    """
    db, dbname = _get_db()
    async with db.sqlorContext(dbname) as sor:
        recs = await sor.R('pipeline_task_steps', {'task_id': task_id}, sort='step_order')
        result = []
        for rec in (recs or []):
            # Test for a mapping first: every dict-subclass instance also has
            # ``__dict__``, so checking that attribute first would send
            # mapping records through dir() and collect method names.
            if isinstance(rec, dict):
                result.append(dict(rec))
            elif hasattr(rec, '__dict__'):
                result.append({k: getattr(rec, k) for k in dir(rec) if not k.startswith('_')})
            else:
                result.append(dict(rec))
        return result
async def get_step_states(task_id: str) -> Dict[str, str]:
    """Map each step name of a task to that step's current state.

    Built from the rows returned by ``get_task_steps(task_id)``; if two rows
    share a ``step_name`` the later one wins.
    """
    states = {}
    for step in await get_task_steps(task_id):
        states[step['step_name']] = step['state']
    return states
async def update_task_state(task_id: str, state: str):
    """Write a new ``state`` value for a task in ``pipeline_tasks``.

    Args:
        task_id: Sent as ``id`` in the update payload.
        state: New value for the ``state`` column.
    """
    pools, name = _get_db()
    changes = {"id": task_id, "state": state}
    async with pools.sqlorContext(name) as sor:
        await sor.U('pipeline_tasks', changes)
async def update_task_version(task_id: str, version: int):
    """Write a new ``current_version`` value for a task in ``pipeline_tasks``.

    Args:
        task_id: Sent as ``id`` in the update payload.
        version: New value for the ``current_version`` column.
    """
    pools, name = _get_db()
    changes = {"id": task_id, "current_version": version}
    async with pools.sqlorContext(name) as sor:
        await sor.U('pipeline_tasks', changes)
async def update_step_state(task_id: str, step_name: str, state: str, error_msg: Optional[str] = None):
    """Set the state of one step of a task and stamp the matching time column.

    The step row is looked up by ``task_id`` and ``step_name``; when no row
    exists the call does nothing. ``started_at`` is stamped when the state
    becomes ``"running"``, ``completed_at`` when it becomes ``"completed"`` or
    ``"failed"``.

    Args:
        task_id: Task the step belongs to.
        step_name: Name of the step to update.
        state: New state value.
        error_msg: Stored in ``error_msg`` only when non-empty.
    """
    # Function-scope import: the module does not import datetime at file level.
    from datetime import datetime

    db, dbname = _get_db()
    async with db.sqlorContext(dbname) as sor:
        recs = await sor.R('pipeline_task_steps', {'task_id': task_id, 'step_name': step_name})
        if not recs:
            return
        rec = recs[0]
        rec_id = rec.id if hasattr(rec, 'id') else rec['id']
        data = {"id": rec_id, "state": state}
        # A Python string "NOW()" passed as a column value is sent as text,
        # not evaluated as the SQL function, so supply a real timestamp.
        # NOTE(review): this is the application host's local clock, not the
        # database server's - confirm both use the same timezone.
        now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        if state == "running":
            data["started_at"] = now
        elif state in ("completed", "failed"):
            data["completed_at"] = now
        if error_msg:
            data["error_msg"] = error_msg
        await sor.U('pipeline_task_steps', data)
async def save_artifact(task_id: str, version: int, step_name: str, io_type: str, data: dict):
    """Store an artifact, overwriting the existing one for the same key.

    An artifact is identified by ``(task_id, version, step_name, io_type)``.
    When a matching ``pipeline_artifacts`` row exists its ``data`` column is
    updated; otherwise a new row with a fresh ``getID()`` id is created.

    Args:
        task_id: Task the artifact belongs to.
        version: Task version the artifact belongs to.
        step_name: Step that the artifact is attached to.
        io_type: Stored in the ``io_type`` column.
        data: Payload, serialized to JSON (non-ASCII kept, values that are
            not JSON-serializable converted with ``str``).
    """
    pools, name = _get_db()
    async with pools.sqlorContext(name) as sor:
        lookup = {
            'task_id': task_id, 'version': version,
            'step_name': step_name, 'io_type': io_type
        }
        found = await sor.R('pipeline_artifacts', lookup)
        payload = json.dumps(data, ensure_ascii=False, default=str)
        if not found:
            # No row yet for this key: insert a new one.
            await sor.C('pipeline_artifacts', {
                "id": getID(),
                "task_id": task_id,
                "version": version,
                "step_name": step_name,
                "io_type": io_type,
                "data": payload,
            })
            return
        first = found[0]
        row_id = first.id if hasattr(first, 'id') else first['id']
        await sor.U('pipeline_artifacts', {"id": row_id, "data": payload})
async def get_artifact(task_id: str, version: int, step_name: str, io_type: str) -> Optional[dict]:
    """Load the payload of a single artifact.

    Args:
        task_id: Task the artifact belongs to.
        version: Task version the artifact belongs to.
        step_name: Step that the artifact is attached to.
        io_type: Value matched against the ``io_type`` column.

    Returns:
        ``None`` when no row matches. Otherwise the ``data`` value of the
        first matching row, JSON-decoded when it is a string and returned
        unchanged when it is not.
    """
    pools, name = _get_db()
    async with pools.sqlorContext(name) as sor:
        rows = await sor.R('pipeline_artifacts', {
            'task_id': task_id, 'version': version,
            'step_name': step_name, 'io_type': io_type
        })
        if not rows:
            return None
        first = rows[0]
        payload = first.data if hasattr(first, 'data') else first['data']
        return json.loads(payload) if isinstance(payload, str) else payload
async def get_all_artifacts(task_id: str, version: int) -> Dict[str, dict]:
    """Collect every artifact payload of one task version.

    Args:
        task_id: Task whose artifacts are read.
        version: Task version whose artifacts are read.

    Returns:
        A dict keyed by ``"<step_name>_<io_type>"``. Each value is the row's
        ``data``, JSON-decoded when it is a string. Empty when no rows match.
    """
    def _field(row, column):
        # Attribute access when the record exposes it, item access otherwise.
        return getattr(row, column) if hasattr(row, column) else row[column]

    pools, name = _get_db()
    async with pools.sqlorContext(name) as sor:
        rows = await sor.R('pipeline_artifacts', {'task_id': task_id, 'version': version})
        artifacts = {}
        for row in (rows or []):
            step = _field(row, 'step_name')
            kind = _field(row, 'io_type')
            payload = _field(row, 'data')
            if isinstance(payload, str):
                payload = json.loads(payload)
            artifacts[f"{step}_{kind}"] = payload
        return artifacts
async def list_tasks(tenant_id: str, pipeline_id: Optional[str] = None, limit: int = 100) -> list:
    """List a tenant's tasks, newest first.

    Args:
        tenant_id: Only rows with this ``tenant_id`` are returned.
        pipeline_id: When truthy, additionally restrict to this pipeline.
        limit: Maximum number of rows to return. Negative values are treated
            as 0; ``None`` disables the limit.

    Returns:
        A list of ``pipeline_tasks`` rows as plain dicts, ordered by
        ``created_at`` descending.
    """
    sql = "SELECT * FROM pipeline_tasks WHERE tenant_id=${tenant_id}$"
    params = {'tenant_id': tenant_id}
    if pipeline_id:
        sql += " AND pipeline_id=${pipeline_id}$"
        params['pipeline_id'] = pipeline_id
    sql += " ORDER BY created_at DESC"
    if limit is not None:
        # Let the database truncate the result instead of fetching every row
        # and slicing in Python. int() keeps the inlined value injection-safe.
        sql += f" LIMIT {max(int(limit), 0)}"

    db, dbname = _get_db()
    async with db.sqlorContext(dbname) as sor:
        recs = await sor.sqlExe(sql, params)
        result = []
        for rec in (recs or []):
            # Test for a mapping first: every dict-subclass instance also has
            # ``__dict__``, so checking that attribute first would send
            # mapping records through dir() and collect method names.
            if isinstance(rec, dict):
                result.append(dict(rec))
            elif hasattr(rec, '__dict__'):
                result.append({k: getattr(rec, k) for k in dir(rec) if not k.startswith('_')})
            else:
                result.append(dict(rec))
        return result
async def reset_steps(task_id: str, step_names: list):
    """Put the named steps of a task back into the ``"pending"`` state.

    For each name the step row is looked up by ``task_id`` and ``step_name``;
    names without a row are skipped. A found row has ``state`` set to
    ``"pending"`` and ``error_msg``, ``started_at`` and ``completed_at`` set
    to ``None``.

    Args:
        task_id: Task whose steps are reset.
        step_names: Names of the steps to reset.
    """
    pools, name = _get_db()
    async with pools.sqlorContext(name) as sor:
        for step_name in step_names:
            found = await sor.R('pipeline_task_steps', {'task_id': task_id, 'step_name': step_name})
            if not found:
                continue
            first = found[0]
            row_id = first.id if hasattr(first, 'id') else first['id']
            cleared = {
                "id": row_id, "state": "pending",
                "error_msg": None, "started_at": None, "completed_at": None
            }
            await sor.U('pipeline_task_steps', cleared)