feat: add workflow run dispatch foundation

This commit is contained in:
2026-04-07 10:11:46 +02:00
parent ab1b220e79
commit 6ad34ceed2
7 changed files with 548 additions and 54 deletions
@@ -17,8 +17,13 @@ Nodes whose StepName has no mapping in STEP_TASK_MAP are skipped with a
warning — this lets you add new StepName values to the enum without breaking
existing dispatchers.
"""
from __future__ import annotations
import logging
import uuid
from collections import deque
from dataclasses import dataclass, field
from typing import Literal
from app.domains.rendering.workflow_schema import WorkflowConfig, WorkflowNode
from app.core.process_steps import StepName
@@ -26,6 +31,25 @@ from app.core.process_steps import StepName
# Module-wide logger; all dispatch messages below are tagged "[WORKFLOW]".
logger = logging.getLogger(__name__)
# How a workflow run is executed. The semantics of each mode are not visible in
# this chunk — presumably "legacy" is the pre-graph path, "graph" the node-graph
# dispatcher, and "shadow" a compare/dry-run mode; confirm against callers.
WorkflowExecutionMode = Literal["legacy", "graph", "shadow"]
@dataclass(slots=True)
class WorkflowContext:
    """Validated runtime state for one workflow run, ready for dispatch."""

    # Correlation id passed as the first positional argument to every Celery task.
    context_id: str
    # "legacy" | "graph" | "shadow"; in this module it is only used in log output.
    execution_mode: WorkflowExecutionMode
    # Id of the persisted workflow run, when one exists — presumably a DB row id;
    # TODO confirm (its origin is not visible in this chunk).
    workflow_run_id: uuid.UUID | None = None
    # Nodes in topological order (dependencies first), from _topological_sort.
    ordered_nodes: list[WorkflowNode] = field(default_factory=list)
@dataclass(slots=True)
class WorkflowDispatchResult:
    """Outcome of dispatching a prepared workflow to Celery."""

    # The runtime context the dispatch ran with.
    context: WorkflowContext
    # Celery task ids in dispatch (topological) order, one per mapped node.
    task_ids: list[str]
    # Mapping of node id -> Celery task id for every dispatched node.
    node_task_ids: dict[str, str]
    # Ids of nodes skipped because their step has no STEP_TASK_MAP entry.
    skipped_node_ids: list[str]
# ---------------------------------------------------------------------------
# Step → Celery task name mapping
#
@@ -59,7 +83,76 @@ STEP_TASK_MAP: dict[StepName, str] = {
}
def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]:
def prepare_workflow_context(
    workflow_config: dict,
    *,
    context_id: str,
    execution_mode: WorkflowExecutionMode = "graph",
    workflow_run_id: uuid.UUID | None = None,
) -> WorkflowContext:
    """Turn a raw workflow config into a validated, dispatch-ready context.

    Args:
        workflow_config: Raw mapping validated against the WorkflowConfig schema.
        context_id: Correlation id forwarded to every dispatched task.
        execution_mode: How the run is executed; defaults to "graph".
        workflow_run_id: Optional id of the persisted workflow run.

    Returns:
        A WorkflowContext whose nodes are topologically sorted
        (dependencies first).

    Raises:
        pydantic.ValidationError: if the config does not match WorkflowConfig.
        ValueError: if the node graph contains a cycle (from _topological_sort).
    """
    parsed = WorkflowConfig.model_validate(workflow_config)
    return WorkflowContext(
        context_id=context_id,
        execution_mode=execution_mode,
        workflow_run_id=workflow_run_id,
        ordered_nodes=_topological_sort(parsed),
    )
def dispatch_prepared_workflow(context: WorkflowContext) -> WorkflowDispatchResult:
    """Send one Celery task per prepared node, honoring topological order.

    Nodes whose step has no entry in STEP_TASK_MAP are recorded as skipped
    (with a warning) rather than raising, so new StepName values do not break
    existing dispatchers.

    Args:
        context: Prepared runtime context (see prepare_workflow_context).

    Returns:
        WorkflowDispatchResult carrying the context, dispatched task ids in
        order, the node-id -> task-id mapping, and the skipped node ids.
    """
    # Deferred import — presumably avoids a circular import at module load
    # time; confirm against app.tasks.celery_app.
    from app.tasks.celery_app import celery_app

    dispatched: list[str] = []
    by_node: dict[str, str] = {}
    skipped: list[str] = []

    for node in context.ordered_nodes:
        mapped_task = STEP_TASK_MAP.get(node.step)
        if mapped_task is None:
            # Unknown step: warn and continue instead of failing the whole run.
            logger.warning(
                "[WORKFLOW] No Celery task mapping for step %r in mode %s - skipping node %r",
                node.step,
                context.execution_mode,
                node.id,
            )
            skipped.append(node.id)
            continue
        async_result = celery_app.send_task(
            mapped_task, args=[context.context_id], kwargs=node.params
        )
        dispatched.append(async_result.id)
        by_node[node.id] = async_result.id
        logger.info(
            "[WORKFLOW] Dispatched node %r (step=%s, mode=%s, run=%s) -> Celery task %s",
            node.id,
            node.step,
            context.execution_mode,
            context.workflow_run_id,
            async_result.id,
        )

    logger.info(
        "[WORKFLOW] dispatch_prepared_workflow complete: %d task(s) dispatched for context %s",
        len(dispatched),
        context.context_id,
    )
    return WorkflowDispatchResult(
        context=context,
        task_ids=dispatched,
        node_task_ids=by_node,
        skipped_node_ids=skipped,
    )
def dispatch_workflow_with_context(
    workflow_config: dict,
    *,
    context_id: str,
    execution_mode: WorkflowExecutionMode = "graph",
    workflow_run_id: uuid.UUID | None = None,
) -> WorkflowDispatchResult:
    """Execute workflow nodes in topological order as individual Celery tasks.

    Thin composition of prepare_workflow_context (validate + sort) and
    dispatch_prepared_workflow (send the Celery tasks).

    Args:
        workflow_config: Raw workflow configuration mapping, validated against
            the WorkflowConfig schema.
        context_id: Correlation id passed as the first positional argument to
            every dispatched Celery task.
        execution_mode: How the run is executed ("legacy" | "graph" |
            "shadow"); used for logging in this module.
        workflow_run_id: Optional id of the persisted workflow run.

    Returns:
        Dispatch result including runtime context, dispatched task IDs and
        skipped nodes.

    Raises:
        pydantic.ValidationError: if *workflow_config* does not conform to the
            WorkflowConfig schema.
        ValueError: if the node graph contains a cycle.
    """
    context = prepare_workflow_context(
        workflow_config,
        context_id=context_id,
        execution_mode=execution_mode,
        workflow_run_id=workflow_run_id,
    )
    return dispatch_prepared_workflow(context)
def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]:
    """Backward-compatible wrapper returning only dispatched Celery task IDs."""
    result = dispatch_workflow_with_context(workflow_config, context_id=context_id)
    return result.task_ids
# ---------------------------------------------------------------------------