Files
HartOMat/backend/app/domains/rendering/workflow_executor.py
T

268 lines
10 KiB
Python

"""Execute a validated WorkflowDefinition config as individual Celery tasks.

Usage::

    from app.domains.rendering.workflow_executor import dispatch_workflow

    task_ids = dispatch_workflow(workflow_definition.config, context_id=str(order_line_id))

``dispatch_workflow``:

1. Validates the raw config dict via ``WorkflowConfig``.
2. Sorts nodes into dependency order (topological sort, Kahn's algorithm).
3. Maps each ``StepName`` to the corresponding Celery task name string.
4. Sends each task independently (not chained) so that every node gets its own
   Celery task ID and can be tracked individually. Because nothing is chained,
   the topological order only fixes the order in which tasks are *submitted*;
   this module does not make a task wait for its upstream nodes to finish.

Nodes whose ``StepName`` has no mapping in ``STEP_TASK_MAP`` are skipped with a
warning. This lets you add new ``StepName`` values to the enum without breaking
existing dispatchers.

``dispatch_workflow_with_context`` returns the full ``WorkflowDispatchResult``
(runtime context, per-node task IDs and skipped node IDs) instead of only the
list of task IDs.
"""
from __future__ import annotations
import logging
import uuid
from collections import deque
from dataclasses import dataclass, field
from typing import Literal
from app.domains.rendering.workflow_schema import WorkflowConfig, WorkflowEdge, WorkflowNode
from app.core.process_steps import StepName
# Module logger; records are namespaced under this module's import path.
logger = logging.getLogger(name=__name__)

# Execution modes accepted by the dispatch helpers. This is a static type hint
# only; the helpers in this module do not validate the value at runtime.
WorkflowExecutionMode = Literal["legacy", "graph", "shadow"]
@dataclass(slots=True)
class WorkflowContext:
    """Runtime state for one workflow execution, built by ``prepare_workflow_context``.

    Attributes:
        context_id: ID of the entity the workflow runs against. It is passed as
            the only positional argument to every task sent by
            ``dispatch_prepared_workflow``.
        execution_mode: Mode the workflow runs in. It is included in dispatch
            log messages; nothing in this module branches on it.
        workflow_run_id: Optional run identifier, included in dispatch log
            messages (presumably the ID of a persisted run record; confirm
            with callers).
        ordered_nodes: Validated nodes in topological (dependency) order.
        edges: Edges copied from the validated workflow config.
    """
    context_id: str
    execution_mode: WorkflowExecutionMode
    workflow_run_id: uuid.UUID | None = None
    ordered_nodes: list[WorkflowNode] = field(default_factory=list)
    edges: list[WorkflowEdge] = field(default_factory=list)
@dataclass(slots=True)
class WorkflowTaskDispatchSpec:
    """A Celery task prepared for later submission by ``submit_prepared_workflow_tasks``.

    Attributes:
        node_id: Workflow node the task belongs to; used in submission error messages.
        task_name: Dotted Celery task name passed to ``celery_app.send_task``.
        args: Positional arguments forwarded to the task.
        kwargs: Keyword arguments forwarded to the task.
        task_id: Celery task ID assigned up front and passed as ``task_id``.
        queue: Optional queue name; only forwarded to ``send_task`` when non-empty.
    """
    node_id: str
    task_name: str
    args: list[str]
    kwargs: dict
    task_id: str
    queue: str | None = None
@dataclass(slots=True)
class WorkflowDispatchResult:
    """Outcome of dispatching a prepared workflow.

    Attributes:
        context: The context that was dispatched.
        task_ids: Celery task IDs, in the order the tasks were sent.
        node_task_ids: Maps each dispatched node id to its Celery task ID.
        skipped_node_ids: IDs of nodes whose step had no entry in ``STEP_TASK_MAP``.
        task_specs: Specs for deferred submission via
            ``submit_prepared_workflow_tasks``. ``dispatch_prepared_workflow``
            leaves this empty.
    """
    context: WorkflowContext
    task_ids: list[str]
    node_task_ids: dict[str, str]
    skipped_node_ids: list[str]
    task_specs: list[WorkflowTaskDispatchSpec] = field(default_factory=list)
class WorkflowTaskSubmissionError(RuntimeError):
    """Raised when sending a prepared workflow task to Celery fails.

    ``submitted_task_ids`` holds the IDs of the tasks that were already sent
    before the failure. It is stored as an independent copy of the list the
    caller passes in, so the caller can see which tasks were sent.
    """

    def __init__(self, message: str, *, submitted_task_ids: list[str] | None = None) -> None:
        super().__init__(message)
        already_sent = submitted_task_ids or []
        # Copy so later appends by the caller do not leak into the exception.
        self.submitted_task_ids = [*already_sent]
def submit_prepared_workflow_tasks(dispatch_result: WorkflowDispatchResult) -> None:
    """Send the pre-built task specs of *dispatch_result* to Celery, in order.

    Meant to be called only after the related DB state has been committed.
    Each spec is sent with its pre-assigned ``task_id``. ``queue`` is passed
    only when the spec sets a non-empty one.

    Raises:
        WorkflowTaskSubmissionError: on the first ``send_task`` call that
            raises. Its ``submitted_task_ids`` lists the specs successfully sent
            before the failure, and the original exception is chained as
            ``__cause__``. Specs after the failing one are not attempted.
    """
    from app.tasks.celery_app import celery_app

    already_sent: list[str] = []
    for task_spec in dispatch_result.task_specs:
        send_options: dict[str, str] = dict(task_id=task_spec.task_id)
        if task_spec.queue:
            send_options["queue"] = task_spec.queue
        try:
            celery_app.send_task(
                task_spec.task_name,
                args=task_spec.args,
                kwargs=task_spec.kwargs,
                **send_options,
            )
        except Exception as exc:
            raise WorkflowTaskSubmissionError(
                f"Failed to submit workflow task for node '{task_spec.node_id}': {exc}",
                submitted_task_ids=already_sent,
            ) from exc
        else:
            already_sent.append(task_spec.task_id)
# ---------------------------------------------------------------------------
# Step → Celery task name mapping
#
# Values are the full dotted task name strings passed to celery_app.send_task().
# Keep this in sync with the @celery_app.task(name=...) decorators in the
# various tasks modules.
#
# dispatch_prepared_workflow() sends one task per workflow node whose step is a
# key here. Nodes whose step is missing are logged as a warning and skipped.
#
# NOTE(review): RESOLVE_STEP_PATH, OCC_OBJECT_EXTRACT and STL_CACHE_GENERATE all
# map to process_step_file. A graph containing several of these steps therefore
# submits that task once per node. Confirm this duplication is intended.
# ---------------------------------------------------------------------------
STEP_TASK_MAP: dict[StepName, str] = {
    # ── STEP file processing ─────────────────────────────────────────────
    StepName.RESOLVE_STEP_PATH: "app.tasks.step_tasks.process_step_file",
    StepName.OCC_OBJECT_EXTRACT: "app.tasks.step_tasks.process_step_file",
    StepName.OCC_GLB_EXPORT: "app.tasks.step_tasks.generate_gltf_geometry_task",
    StepName.STL_CACHE_GENERATE: "app.tasks.step_tasks.process_step_file",
    # ── Thumbnail generation ─────────────────────────────────────────────
    StepName.BLENDER_RENDER: "app.tasks.step_tasks.render_step_thumbnail",
    StepName.THUMBNAIL_SAVE: "app.tasks.step_tasks.render_graph_thumbnail",
    # ── Order line stills & turntables ──────────────────────────────────
    StepName.BLENDER_STILL: "app.domains.rendering.tasks.render_order_line_still_task",
    StepName.BLENDER_TURNTABLE: "app.domains.rendering.tasks.render_turntable_task",
    # ── Asset export ─────────────────────────────────────────────────────
    StepName.EXPORT_BLEND: "app.domains.rendering.tasks.export_blend_for_order_line_task",
    # ── Steps without a dedicated standalone task (no mapping) ───────────
    # Workflow nodes using any of these steps are skipped at dispatch time.
    # StepName.GLB_BBOX — computed inline inside process_step_file
    # StepName.MATERIAL_MAP_RESOLVE — computed inline inside render tasks
    # StepName.AUTO_POPULATE_MATERIALS — computed inline inside process_step_file
    # StepName.THREEJS_RENDER — no standalone task exists yet
    # StepName.ORDER_LINE_SETUP — computed inline inside render_order_line_task
    # StepName.RESOLVE_TEMPLATE — computed inline inside render_order_line_task
    # StepName.OUTPUT_SAVE — handled via publish_asset after render tasks
    # StepName.NOTIFY — emitted inline via notification_service
}
def prepare_workflow_context(
    workflow_config: dict,
    *,
    context_id: str,
    execution_mode: WorkflowExecutionMode = "graph",
    workflow_run_id: uuid.UUID | None = None,
) -> WorkflowContext:
    """Build a ``WorkflowContext`` from a raw workflow config without dispatching anything.

    Args:
        workflow_config: Raw config dict; validated with ``WorkflowConfig.model_validate``.
        context_id: ID of the entity the workflow runs against.
        execution_mode: Mode recorded on the context (default ``"graph"``).
        workflow_run_id: Optional run identifier recorded on the context.

    Returns:
        A context whose ``ordered_nodes`` are in dependency order and whose
        ``edges`` are a fresh list of the validated config's edges.

    Raises:
        pydantic.ValidationError: if the config does not match the
            ``WorkflowConfig`` schema.
        ValueError: if the node graph contains a cycle.
    """
    validated = WorkflowConfig.model_validate(workflow_config)
    return WorkflowContext(
        context_id=context_id,
        execution_mode=execution_mode,
        workflow_run_id=workflow_run_id,
        ordered_nodes=_topological_sort(validated),
        edges=[*validated.edges],
    )
def dispatch_prepared_workflow(context: WorkflowContext) -> WorkflowDispatchResult:
    """Send every mapped node of *context* to Celery as an independent task.

    Nodes are visited in ``context.ordered_nodes`` order. A node whose step is
    a key of ``STEP_TASK_MAP`` is sent immediately, with ``context.context_id``
    as its only positional argument and the node's ``params`` as keyword
    arguments. Any other node is logged as a warning and recorded as skipped.

    The returned result's ``task_specs`` is left empty: tasks are sent directly
    here rather than prepared for ``submit_prepared_workflow_tasks``. If
    ``send_task`` raises, the exception propagates unchanged, and tasks already
    sent for earlier nodes are not reported back.

    Returns:
        The context together with the dispatched task IDs (in send order), a
        node-id → task-id map, and the IDs of skipped nodes.
    """
    from app.tasks.celery_app import celery_app

    dispatched: list[str] = []
    by_node: dict[str, str] = {}
    skipped: list[str] = []

    for node in context.ordered_nodes:
        if (task_name := STEP_TASK_MAP.get(node.step)) is None:
            logger.warning(
                "[WORKFLOW] No Celery task mapping for step %r in mode %s - skipping node %r",
                node.step,
                context.execution_mode,
                node.id,
            )
            skipped.append(node.id)
        else:
            celery_task_id = celery_app.send_task(
                task_name, args=[context.context_id], kwargs=node.params
            ).id
            dispatched.append(celery_task_id)
            by_node[node.id] = celery_task_id
            logger.info(
                "[WORKFLOW] Dispatched node %r (step=%s, mode=%s, run=%s) -> Celery task %s",
                node.id,
                node.step,
                context.execution_mode,
                context.workflow_run_id,
                celery_task_id,
            )

    logger.info(
        "[WORKFLOW] dispatch_prepared_workflow complete: %d task(s) dispatched for context %s",
        len(dispatched),
        context.context_id,
    )
    return WorkflowDispatchResult(
        context=context,
        task_ids=dispatched,
        node_task_ids=by_node,
        skipped_node_ids=skipped,
    )
def dispatch_workflow_with_context(
    workflow_config: dict,
    *,
    context_id: str,
    execution_mode: WorkflowExecutionMode = "graph",
    workflow_run_id: uuid.UUID | None = None,
) -> WorkflowDispatchResult:
    """Validate, order and dispatch a workflow config in one call.

    Equivalent to ``prepare_workflow_context`` followed by
    ``dispatch_prepared_workflow``. Each mapped node becomes its own Celery
    task, submitted in topological order.

    Args:
        workflow_config: Raw dict from WorkflowDefinition.config JSONB.
        context_id: UUID string of the entity being processed. Depending on
            the StepName, this is interpreted as a ``cad_file_id`` (for STEP
            processing / thumbnail tasks) or an ``order_line_id`` (for render
            tasks).
        execution_mode: Stored on the runtime context and included in dispatch
            log messages. Defaults to ``"graph"``.
        workflow_run_id: Optional run identifier, stored on the runtime context
            and included in dispatch log messages.

    Returns:
        Dispatch result including runtime context, dispatched task IDs and
        skipped nodes.

    Raises:
        pydantic.ValidationError: if *workflow_config* does not conform to
            WorkflowConfig schema.
        ValueError: if the node graph contains a cycle.
    """
    prepared = prepare_workflow_context(
        workflow_config,
        context_id=context_id,
        execution_mode=execution_mode,
        workflow_run_id=workflow_run_id,
    )
    return dispatch_prepared_workflow(prepared)
def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]:
    """Dispatch *workflow_config* in the default ``"graph"`` mode and return only the task IDs.

    Backward-compatible wrapper around ``dispatch_workflow_with_context`` for
    callers that only need the Celery task IDs. The richer result (context,
    per-node task IDs, skipped nodes) is discarded.
    """
    full_result = dispatch_workflow_with_context(workflow_config, context_id=context_id)
    return full_result.task_ids
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _topological_sort(config: WorkflowConfig) -> list[WorkflowNode]:
    """Return the workflow's nodes in dependency order (Kahn's algorithm).

    Nodes with no incoming edges come first. Nodes that become ready at the
    same time keep their relative order from ``config.nodes``, so the result is
    deterministic for a given config.

    Raises:
        ValueError: if two nodes share an id, if an edge references a node id
            that is not declared in ``config.nodes``, or if the graph contains a
            cycle. All three mean the workflow config is invalid.
    """
    node_map: dict[str, WorkflowNode] = {}
    for node in config.nodes:
        if node.id in node_map:
            # Without this check the duplicate would silently collapse into one
            # dict entry and later be misreported as a cycle.
            raise ValueError(f"Workflow config contains duplicate node id {node.id!r}.")
        node_map[node.id] = node

    in_degree: dict[str, int] = dict.fromkeys(node_map, 0)
    adjacency: dict[str, list[str]] = {nid: [] for nid in node_map}
    for edge in config.edges:
        unknown = [nid for nid in (edge.from_node, edge.to_node) if nid not in node_map]
        if unknown:
            # Report a dangling edge as a config error instead of a bare KeyError.
            raise ValueError(
                f"Workflow edge {edge.from_node!r} -> {edge.to_node!r} references "
                f"unknown node id(s): {', '.join(map(repr, unknown))}."
            )
        adjacency[edge.from_node].append(edge.to_node)
        in_degree[edge.to_node] += 1

    ready: deque[str] = deque(nid for nid, deg in in_degree.items() if deg == 0)
    ordered: list[WorkflowNode] = []
    while ready:
        nid = ready.popleft()
        ordered.append(node_map[nid])
        for neighbor in adjacency[nid]:
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                ready.append(neighbor)

    # Any node never reaching in-degree 0 sits on (or behind) a cycle.
    if len(ordered) != len(node_map):
        raise ValueError(
            "Workflow config contains a cycle — topological sort failed. "
            f"Processed {len(ordered)}/{len(node_map)} nodes."
        )
    return ordered