feat: add workflow run dispatch foundation

This commit is contained in:
2026-04-07 10:11:46 +02:00
parent ab1b220e79
commit 6ad34ceed2
7 changed files with 548 additions and 54 deletions
@@ -13,7 +13,6 @@ task-dispatch logic in app.services.render_dispatcher (legacy path).
from __future__ import annotations
import logging
from datetime import datetime
logger = logging.getLogger(__name__)
@@ -32,7 +31,9 @@ def dispatch_render_with_workflow(order_line_id: str) -> dict:
from app.config import settings
from app.domains.orders.models import OrderLine
from app.domains.rendering.models import OutputType, WorkflowDefinition, WorkflowRun
from app.domains.rendering.models import OutputType, WorkflowDefinition
from app.domains.rendering.workflow_executor import prepare_workflow_context
from app.domains.rendering.workflow_run_service import create_workflow_run, mark_workflow_run_failed
engine = create_engine(
settings.database_url.replace("+asyncpg", ""),
@@ -96,6 +97,22 @@ def dispatch_render_with_workflow(order_line_id: str) -> dict:
workflow_type,
)
try:
workflow_context = prepare_workflow_context(
wf_def.config,
context_id=order_line_id,
execution_mode="legacy",
)
except Exception as exc:
logger.warning(
"order_line %s: workflow_definition_id %s failed runtime preparation (%s), "
"falling back to legacy dispatch",
order_line_id,
wf_def.id,
exc,
)
return _legacy_dispatch(order_line_id)
# For turntable workflows: resolve step_path + output_dir from the order line at runtime
if workflow_type == "turntable" and ("step_path" not in params or "output_dir" not in params):
from app.domains.products.models import CadFile as _CadFile
@@ -120,19 +137,43 @@ def dispatch_render_with_workflow(order_line_id: str) -> dict:
str(_Path(_cfg.upload_dir) / "renders" / str(line.id)),
)
from app.domains.rendering.workflow_builder import dispatch_workflow
celery_task_id = dispatch_workflow(workflow_type, order_line_id, params)
run = None
try:
run = create_workflow_run(
session,
workflow_def_id=wf_def.id,
order_line_id=line.id,
workflow_context=workflow_context,
)
session.commit()
except Exception as exc:
session.rollback()
logger.warning(
"order_line %s: failed to create workflow run for workflow_definition_id %s (%s), "
"falling back to legacy dispatch",
order_line_id,
wf_def.id,
exc,
)
return _legacy_dispatch(order_line_id)
# Persist a WorkflowRun record
run = WorkflowRun(
workflow_def_id=wf_def.id,
order_line_id=line.id,
celery_task_id=celery_task_id,
status="pending",
started_at=datetime.utcnow(),
)
session.add(run)
session.commit()
from app.domains.rendering.workflow_builder import dispatch_workflow
try:
celery_task_id = dispatch_workflow(workflow_type, order_line_id, params)
run.celery_task_id = celery_task_id
session.commit()
except Exception as exc:
session.rollback()
session.add(run)
mark_workflow_run_failed(run, str(exc))
session.commit()
logger.exception(
"order_line %s: workflow dispatch via definition %s failed, falling back to legacy dispatch",
order_line_id,
wf_def.id,
)
return _legacy_dispatch(order_line_id)
return {
"backend": "workflow",
@@ -17,8 +17,13 @@ Nodes whose StepName has no mapping in STEP_TASK_MAP are skipped with a
warning — this lets you add new StepName values to the enum without breaking
existing dispatchers.
"""
from __future__ import annotations
import logging
import uuid
from collections import deque
from dataclasses import dataclass, field
from typing import Literal
from app.domains.rendering.workflow_schema import WorkflowConfig, WorkflowNode
from app.core.process_steps import StepName
@@ -26,6 +31,25 @@ from app.core.process_steps import StepName
logger = logging.getLogger(__name__)
WorkflowExecutionMode = Literal["legacy", "graph", "shadow"]
@dataclass(slots=True)
class WorkflowContext:
    """Runtime state for one workflow execution, built by prepare_workflow_context().

    Carries the validated, topologically ordered node list plus the identifiers
    a dispatcher needs; ``workflow_run_id`` starts as ``None`` and is assigned
    once a WorkflowRun row has been persisted.
    """

    # Opaque caller-supplied identifier (e.g. an order-line id) forwarded to tasks.
    context_id: str
    # One of "legacy" / "graph" / "shadow" — see WorkflowExecutionMode.
    execution_mode: WorkflowExecutionMode
    # Filled in after a WorkflowRun DB row exists; None until then.
    workflow_run_id: uuid.UUID | None = None
    # Nodes in topologically sorted dispatch order.
    ordered_nodes: list[WorkflowNode] = field(default_factory=list)
@dataclass(slots=True)
class WorkflowDispatchResult:
    """Outcome of dispatch_prepared_workflow(): what was sent to Celery and what was not."""

    # The runtime context the dispatch ran under.
    context: WorkflowContext
    # Celery task ids in dispatch order, one per mapped node.
    task_ids: list[str]
    # Mapping of node id -> Celery task id for every dispatched node.
    node_task_ids: dict[str, str]
    # Ids of nodes skipped because their step has no STEP_TASK_MAP entry.
    skipped_node_ids: list[str]
# ---------------------------------------------------------------------------
# Step → Celery task name mapping
#
@@ -59,7 +83,76 @@ STEP_TASK_MAP: dict[StepName, str] = {
}
def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]:
def prepare_workflow_context(
    workflow_config: dict,
    *,
    context_id: str,
    execution_mode: WorkflowExecutionMode = "graph",
    workflow_run_id: uuid.UUID | None = None,
) -> WorkflowContext:
    """Validate *workflow_config* and assemble the runtime execution context.

    Args:
        workflow_config: Raw config dict; validated against WorkflowConfig.
        context_id: Opaque identifier forwarded to every dispatched task.
        execution_mode: How this context will be executed (default "graph").
        workflow_run_id: Optional pre-existing WorkflowRun id to attach.

    Returns:
        A WorkflowContext whose nodes are already topologically sorted.

    Raises:
        pydantic.ValidationError: if the config does not match the schema.
        ValueError: if the node graph contains a cycle (from the sort).
    """
    parsed = WorkflowConfig.model_validate(workflow_config)
    sorted_nodes = _topological_sort(parsed)
    return WorkflowContext(
        context_id=context_id,
        execution_mode=execution_mode,
        workflow_run_id=workflow_run_id,
        ordered_nodes=sorted_nodes,
    )
def dispatch_prepared_workflow(context: WorkflowContext) -> WorkflowDispatchResult:
    """Execute prepared workflow nodes in topological order as individual Celery tasks.

    Nodes whose step has no STEP_TASK_MAP entry are recorded as skipped
    (with a warning) rather than raising, so new StepName values do not
    break existing dispatchers.
    """
    from app.tasks.celery_app import celery_app

    dispatched: list[str] = []
    per_node: dict[str, str] = {}
    skipped: list[str] = []

    for node in context.ordered_nodes:
        task_name = STEP_TASK_MAP.get(node.step)
        if task_name is None:
            # Unmapped step: note it and keep going with the rest of the graph.
            logger.warning(
                "[WORKFLOW] No Celery task mapping for step %r in mode %s - skipping node %r",
                node.step,
                context.execution_mode,
                node.id,
            )
            skipped.append(node.id)
            continue

        async_result = celery_app.send_task(
            task_name, args=[context.context_id], kwargs=node.params
        )
        dispatched.append(async_result.id)
        per_node[node.id] = async_result.id
        logger.info(
            "[WORKFLOW] Dispatched node %r (step=%s, mode=%s, run=%s) -> Celery task %s",
            node.id,
            node.step,
            context.execution_mode,
            context.workflow_run_id,
            async_result.id,
        )

    logger.info(
        "[WORKFLOW] dispatch_prepared_workflow complete: %d task(s) dispatched for context %s",
        len(dispatched),
        context.context_id,
    )
    return WorkflowDispatchResult(
        context=context,
        task_ids=dispatched,
        node_task_ids=per_node,
        skipped_node_ids=skipped,
    )
def dispatch_workflow_with_context(
workflow_config: dict,
*,
context_id: str,
execution_mode: WorkflowExecutionMode = "graph",
workflow_run_id: uuid.UUID | None = None,
) -> WorkflowDispatchResult:
"""Execute workflow nodes in topological order as individual Celery tasks.
Args:
@@ -70,47 +163,25 @@ def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]:
tasks).
Returns:
List of dispatched Celery task ID strings (one per mapped node).
Dispatch result including runtime context, dispatched task IDs and skipped nodes.
Raises:
pydantic.ValidationError: if *workflow_config* does not conform to
WorkflowConfig schema.
ValueError: if the node graph contains a cycle.
"""
from app.tasks.celery_app import celery_app
# Validate + parse config
config = WorkflowConfig.model_validate(workflow_config)
# Topological sort so that dependency nodes dispatch first
ordered_nodes = _topological_sort(config)
task_ids: list[str] = []
for node in ordered_nodes:
task_name = STEP_TASK_MAP.get(node.step)
if task_name is None:
logger.warning(
"[WORKFLOW] No Celery task mapping for step %r — skipping node %r",
node.step,
node.id,
)
continue
result = celery_app.send_task(task_name, args=[context_id], kwargs=node.params)
task_ids.append(result.id)
logger.info(
"[WORKFLOW] Dispatched node %r (step=%s) → Celery task %s",
node.id,
node.step,
result.id,
)
logger.info(
"[WORKFLOW] dispatch_workflow complete: %d task(s) dispatched for context %s",
len(task_ids),
context_id,
context = prepare_workflow_context(
workflow_config,
context_id=context_id,
execution_mode=execution_mode,
workflow_run_id=workflow_run_id,
)
return task_ids
return dispatch_prepared_workflow(context)
def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]:
    """Backward-compatible wrapper returning only dispatched Celery task IDs."""
    result = dispatch_workflow_with_context(workflow_config, context_id=context_id)
    return result.task_ids
# ---------------------------------------------------------------------------
@@ -200,6 +200,9 @@ async def list_workflow_runs(
class WorkflowDispatchResponse(BaseModel):
workflow_run: WorkflowRunOut
context_id: str
execution_mode: str
dispatched: int
task_ids: list[str]
@@ -225,7 +228,15 @@ async def dispatch_workflow_endpoint(
the caller can track progress.
"""
from pydantic import ValidationError as _ValidationError
from app.domains.rendering.workflow_executor import dispatch_workflow
from app.domains.rendering.workflow_executor import (
dispatch_prepared_workflow,
prepare_workflow_context,
)
from app.domains.rendering.workflow_run_service import (
apply_graph_dispatch_result,
create_workflow_run,
mark_workflow_run_failed,
)
result = await db.execute(
select(WorkflowDefinition).where(WorkflowDefinition.id == workflow_id)
@@ -237,10 +248,59 @@ async def dispatch_workflow_endpoint(
raise HTTPException(status_code=400, detail="Workflow has no config")
try:
task_ids = dispatch_workflow(wf.config, context_id)
workflow_context = prepare_workflow_context(
wf.config,
context_id=context_id,
execution_mode="graph",
)
except _ValidationError as exc:
raise HTTPException(status_code=422, detail=f"Invalid workflow config: {exc.errors()}")
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc))
return WorkflowDispatchResponse(dispatched=len(task_ids), task_ids=task_ids)
run_id = await db.run_sync(
lambda sync_session: create_workflow_run(
sync_session,
workflow_def_id=wf.id,
order_line_id=None,
workflow_context=workflow_context,
).id
)
await db.commit()
try:
dispatch_result = dispatch_prepared_workflow(workflow_context)
except Exception as exc:
failed_result = await db.execute(
select(WorkflowRun)
.where(WorkflowRun.id == run_id)
.options(selectinload(WorkflowRun.node_results))
)
failed_run = failed_result.scalar_one()
mark_workflow_run_failed(failed_run, str(exc))
await db.commit()
raise
run_result = await db.execute(
select(WorkflowRun)
.where(WorkflowRun.id == run_id)
.options(selectinload(WorkflowRun.node_results))
)
run = run_result.scalar_one()
apply_graph_dispatch_result(run, workflow_context, dispatch_result)
await db.commit()
refreshed_result = await db.execute(
select(WorkflowRun)
.where(WorkflowRun.id == run_id)
.options(selectinload(WorkflowRun.node_results))
)
refreshed_run = refreshed_result.scalar_one()
return WorkflowDispatchResponse(
workflow_run=refreshed_run,
context_id=context_id,
execution_mode=workflow_context.execution_mode,
dispatched=len(dispatch_result.task_ids),
task_ids=dispatch_result.task_ids,
)
@@ -0,0 +1,81 @@
from __future__ import annotations
from datetime import datetime
from sqlalchemy.orm import Session
from app.domains.rendering.models import WorkflowNodeResult, WorkflowRun
from app.domains.rendering.workflow_executor import WorkflowContext, WorkflowDispatchResult
def create_workflow_run(
    session: Session,
    *,
    workflow_def_id,
    order_line_id,
    workflow_context: WorkflowContext,
) -> WorkflowRun:
    """Persist a pending WorkflowRun plus one pending node-result row per node.

    Flushes (but does not commit) so the generated run id is available, writes
    it back onto *workflow_context*, and seeds a WorkflowNodeResult per ordered
    node with step/label metadata in its ``output`` column.

    Returns:
        The flushed, still-uncommitted WorkflowRun.
    """
    run = WorkflowRun(
        workflow_def_id=workflow_def_id,
        order_line_id=order_line_id,
        status="pending",
        started_at=datetime.utcnow(),
    )
    session.add(run)
    session.flush()  # assigns run.id so node rows can reference it

    workflow_context.workflow_run_id = run.id

    node_rows = []
    for node in workflow_context.ordered_nodes:
        node_meta = {"step": node.step.value}
        if node.ui and node.ui.label:
            node_meta["label"] = node.ui.label
        node_rows.append(
            WorkflowNodeResult(
                run_id=run.id,
                node_name=node.id,
                status="pending",
                output=node_meta,
            )
        )
    session.add_all(node_rows)
    session.flush()
    return run
def apply_graph_dispatch_result(
    run: WorkflowRun,
    workflow_context: WorkflowContext,
    dispatch_result: WorkflowDispatchResult,
) -> None:
    """Mirror a graph dispatch onto the persisted run and its node results.

    Sets the run-level ``celery_task_id`` to the first dispatched task (or
    ``None`` when nothing dispatched), marks dispatched nodes "queued" with
    their task id merged into the output metadata, and marks unmapped nodes
    "skipped" with an explanatory log line. Node results with no matching
    node in the context are left untouched. Caller commits.
    """
    nodes_by_id = {node.id: node for node in workflow_context.ordered_nodes}
    results_by_name = {nr.node_name: nr for nr in run.node_results}
    skipped = set(dispatch_result.skipped_node_ids)

    # First dispatched task doubles as the run-level task handle.
    run.celery_task_id = dispatch_result.task_ids[0] if dispatch_result.task_ids else None

    for node_id, node_result in results_by_name.items():
        node = nodes_by_id.get(node_id)
        if node is None:
            continue

        merged = dict(node_result.output or {})
        merged.setdefault("step", node.step.value)
        if node.ui and node.ui.label:
            merged.setdefault("label", node.ui.label)

        task_id = dispatch_result.node_task_ids.get(node_id)
        if task_id is not None:
            node_result.status = "queued"
            merged["task_id"] = task_id
            node_result.output = merged
        elif node_id in skipped:
            node_result.status = "skipped"
            node_result.output = merged
            node_result.log = f"No Celery task mapping for step '{node.step.value}'"
def mark_workflow_run_failed(run: WorkflowRun, error_message: str) -> None:
    """Stamp *run* as failed, recording completion time and a truncated error.

    Mutates the ORM object only; the caller is responsible for committing.
    """
    run.completed_at = datetime.utcnow()
    run.status = "failed"
    # Keep only the first 2000 characters so the message fits its column.
    run.error_message = error_message[:2000]