"""Dispatch a validated ``WorkflowDefinition`` config as individual Celery tasks.

Usage::

    from app.domains.rendering.workflow_executor import dispatch_workflow

    task_ids = dispatch_workflow(workflow_definition.config, context_id=str(order_line_id))

How it works:

1. The raw config dict is validated via ``WorkflowConfig``.
2. Nodes are ordered by their dependencies (topological sort, Kahn's algorithm).
3. Each node's ``StepName`` is looked up in ``STEP_TASK_MAP`` to obtain a
   Celery task name string.
4. Every task is sent on its own (not chained), so each node receives its own
   Celery task ID and can be tracked individually. Because nothing chains the
   tasks, the topological order only decides the order in which tasks are
   *enqueued*; this module does not make a task wait for its upstream nodes.

Nodes whose ``StepName`` has no entry in ``STEP_TASK_MAP`` are skipped with a
warning, so new ``StepName`` members can be added to the enum without breaking
existing dispatchers.
"""

from __future__ import annotations

import logging
import uuid
from collections import deque
from dataclasses import dataclass, field
from typing import Literal

from app.domains.rendering.workflow_schema import WorkflowConfig, WorkflowNode
from app.core.process_steps import StepName

logger = logging.getLogger(__name__)

# Execution-mode label carried on WorkflowContext and included in dispatch
# log lines.
WorkflowExecutionMode = Literal["legacy", "graph", "shadow"]


@dataclass(slots=True)
class WorkflowContext:
    """Runtime state for a single workflow dispatch."""

    # Entity id handed to every dispatched task as its only positional arg.
    context_id: str
    execution_mode: WorkflowExecutionMode
    # Optional run identifier; stored here and included in dispatch log lines.
    workflow_run_id: uuid.UUID | None = None
    # Workflow nodes in dependency (topological) order.
    ordered_nodes: list[WorkflowNode] = field(default_factory=list)


@dataclass(slots=True)
class WorkflowDispatchResult:
    """Outcome of dispatching a prepared workflow."""

    context: WorkflowContext
    # Celery task ids, in the order the tasks were sent.
    task_ids: list[str]
    # node.id -> Celery task id, for every node that was dispatched.
    node_task_ids: dict[str, str]
    # Ids of nodes skipped because their step has no STEP_TASK_MAP entry.
    skipped_node_ids: list[str]


# ---------------------------------------------------------------------------
# Step → Celery task name mapping
#
# Values are the full dotted task name strings passed to celery_app.send_task().
# Keep this in sync with the @celery_app.task(name=...) decorators in the
# various tasks modules.
# ---------------------------------------------------------------------------

STEP_TASK_MAP: dict[StepName, str] = {
    # ── STEP file processing ─────────────────────────────────────────────
    # NOTE(review): several steps map to the same task (process_step_file
    # here, render_step_thumbnail below). A workflow containing more than one
    # of them enqueues that task once per node for the same context_id —
    # confirm the task is idempotent or that this is intended.
    StepName.RESOLVE_STEP_PATH: "app.tasks.step_tasks.process_step_file",
    StepName.OCC_OBJECT_EXTRACT: "app.tasks.step_tasks.process_step_file",
    StepName.OCC_GLB_EXPORT: "app.tasks.step_tasks.generate_gltf_geometry_task",
    StepName.STL_CACHE_GENERATE: "app.tasks.step_tasks.process_step_file",
    # ── Thumbnail generation ─────────────────────────────────────────────
    StepName.BLENDER_RENDER: "app.tasks.step_tasks.render_step_thumbnail",
    StepName.THUMBNAIL_SAVE: "app.tasks.step_tasks.render_step_thumbnail",
    # ── Order line stills & turntables ──────────────────────────────────
    StepName.BLENDER_STILL: "app.domains.rendering.tasks.render_order_line_still_task",
    StepName.BLENDER_TURNTABLE: "app.domains.rendering.tasks.render_turntable_task",
    # ── Asset export ─────────────────────────────────────────────────────
    StepName.EXPORT_BLEND: "app.domains.rendering.tasks.export_blend_for_order_line_task",
    # ── Steps without a dedicated standalone task (no mapping) ───────────
    # StepName.GLB_BBOX — computed inline inside process_step_file
    # StepName.MATERIAL_MAP_RESOLVE — computed inline inside render tasks
    # StepName.AUTO_POPULATE_MATERIALS — computed inline inside process_step_file
    # StepName.THREEJS_RENDER — no standalone task exists yet
    # StepName.ORDER_LINE_SETUP — computed inline inside render_order_line_task
    # StepName.RESOLVE_TEMPLATE — computed inline inside render_order_line_task
    # StepName.OUTPUT_SAVE — handled via publish_asset after render tasks
    # StepName.NOTIFY — emitted inline via notification_service
}


def prepare_workflow_context(
    workflow_config: dict,
    *,
    context_id: str,
    execution_mode: WorkflowExecutionMode = "graph",
    workflow_run_id: uuid.UUID | None = None,
) -> WorkflowContext:
    """Validate workflow config and build the runtime context.

    Args:
        workflow_config: Raw dict from WorkflowDefinition.config.
        context_id: Entity id that will be passed to every dispatched task.
        execution_mode: Mode label stored on the context and used in log lines.
        workflow_run_id: Optional run id stored on the context.

    Returns:
        A ``WorkflowContext`` whose ``ordered_nodes`` are in topological order.

    Raises:
        pydantic.ValidationError: if *workflow_config* does not conform to
            the WorkflowConfig schema.
        ValueError: if the node graph is malformed (duplicate node ids, edges
            referencing unknown nodes) or contains a cycle.
    """
    config = WorkflowConfig.model_validate(workflow_config)
    ordered_nodes = _topological_sort(config)
    return WorkflowContext(
        context_id=context_id,
        execution_mode=execution_mode,
        workflow_run_id=workflow_run_id,
        ordered_nodes=ordered_nodes,
    )


def dispatch_prepared_workflow(context: WorkflowContext) -> WorkflowDispatchResult:
    """Execute prepared workflow nodes in topological order as individual Celery tasks.

    Each node with a ``STEP_TASK_MAP`` entry is sent via
    ``celery_app.send_task(task_name, args=[context.context_id],
    kwargs=node.params)``. Tasks are NOT chained: the topological order only
    determines the order in which they are enqueued, not when they run.
    Nodes whose step has no mapping are logged and skipped.

    Returns:
        ``WorkflowDispatchResult`` holding the task ids (in dispatch order),
        the ``node.id -> task id`` mapping and the ids of skipped nodes.
    """
    # Imported lazily — presumably to avoid a circular import with the Celery
    # app module; confirm before hoisting to module level.
    from app.tasks.celery_app import celery_app

    task_ids: list[str] = []
    node_task_ids: dict[str, str] = {}
    skipped_node_ids: list[str] = []

    for node in context.ordered_nodes:
        task_name = STEP_TASK_MAP.get(node.step)
        if task_name is None:
            logger.warning(
                "[WORKFLOW] No Celery task mapping for step %r in mode %s - skipping node %r",
                node.step,
                context.execution_mode,
                node.id,
            )
            skipped_node_ids.append(node.id)
            continue

        result = celery_app.send_task(task_name, args=[context.context_id], kwargs=node.params)
        task_ids.append(result.id)
        node_task_ids[node.id] = result.id
        logger.info(
            "[WORKFLOW] Dispatched node %r (step=%s, mode=%s, run=%s) -> Celery task %s",
            node.id,
            node.step,
            context.execution_mode,
            context.workflow_run_id,
            result.id,
        )

    logger.info(
        "[WORKFLOW] dispatch_prepared_workflow complete: %d task(s) dispatched for context %s",
        len(task_ids),
        context.context_id,
    )
    return WorkflowDispatchResult(
        context=context,
        task_ids=task_ids,
        node_task_ids=node_task_ids,
        skipped_node_ids=skipped_node_ids,
    )


def dispatch_workflow_with_context(
    workflow_config: dict,
    *,
    context_id: str,
    execution_mode: WorkflowExecutionMode = "graph",
    workflow_run_id: uuid.UUID | None = None,
) -> WorkflowDispatchResult:
    """Execute workflow nodes in topological order as individual Celery tasks.

    Convenience wrapper: ``prepare_workflow_context`` followed by
    ``dispatch_prepared_workflow``.

    Args:
        workflow_config: Raw dict from WorkflowDefinition.config JSONB.
        context_id: UUID string of the entity being processed. Depending on
            the StepName, this is interpreted as a ``cad_file_id`` (for STEP
            processing / thumbnail tasks) or an ``order_line_id`` (for render
            tasks). The same value is passed to every dispatched task.
        execution_mode: Mode label stored on the context and used in log lines.
        workflow_run_id: Optional run id stored on the context and logged with
            each dispatched node.

    Returns:
        Dispatch result including runtime context, dispatched task IDs and
        skipped nodes.

    Raises:
        pydantic.ValidationError: if *workflow_config* does not conform to
            WorkflowConfig schema.
        ValueError: if the node graph is malformed (duplicate node ids, edges
            referencing unknown nodes) or contains a cycle.
    """
    context = prepare_workflow_context(
        workflow_config,
        context_id=context_id,
        execution_mode=execution_mode,
        workflow_run_id=workflow_run_id,
    )
    return dispatch_prepared_workflow(context)


def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]:
    """Backward-compatible wrapper returning only dispatched Celery task IDs."""
    return dispatch_workflow_with_context(workflow_config, context_id=context_id).task_ids


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _topological_sort(config: WorkflowConfig) -> list[WorkflowNode]:
    """Return nodes in dependency order using Kahn's algorithm.

    Root nodes (no incoming edges) are taken in ``config.nodes`` order, and
    newly-ready nodes are appended to a FIFO queue, so the result is
    deterministic for a given config.

    Raises:
        ValueError: if two nodes share an id, if an edge references a node id
            not declared in ``config.nodes``, or if the graph contains a
            cycle. The caller should treat all three as a bad workflow config.
    """
    node_map: dict[str, WorkflowNode] = {}
    for node in config.nodes:
        if node.id in node_map:
            # Without this check, duplicates collapse silently in the maps
            # below and the final length check misreports them as a cycle.
            raise ValueError(f"Workflow config contains duplicate node id {node.id!r}.")
        node_map[node.id] = node

    in_degree: dict[str, int] = dict.fromkeys(node_map, 0)
    adjacency: dict[str, list[str]] = {nid: [] for nid in node_map}

    for edge in config.edges:
        missing = [nid for nid in (edge.from_node, edge.to_node) if nid not in node_map]
        if missing:
            # Would otherwise surface as a bare KeyError from the dict lookups.
            raise ValueError(
                f"Workflow config edge {edge.from_node!r} -> {edge.to_node!r} "
                f"references unknown node id(s): {', '.join(map(repr, missing))}."
            )
        adjacency[edge.from_node].append(edge.to_node)
        in_degree[edge.to_node] += 1

    queue: deque[str] = deque(nid for nid, deg in in_degree.items() if deg == 0)
    result: list[WorkflowNode] = []

    while queue:
        nid = queue.popleft()
        result.append(node_map[nid])
        for neighbor in adjacency[nid]:
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                queue.append(neighbor)

    if len(result) != len(node_map):
        # Nodes never released by the sort: cycle members and anything
        # downstream of a cycle.
        unresolved = [nid for nid, deg in in_degree.items() if deg > 0]
        raise ValueError(
            "Workflow config contains a cycle — topological sort failed. "
            f"Processed {len(result)}/{len(node_map)} nodes. "
            f"Unresolved nodes: {unresolved}."
        )

    return result