"""Execute a validated WorkflowDefinition config as individual Celery tasks.
|
|
|
|
Usage::
|
|
|
|
from app.domains.rendering.workflow_executor import dispatch_workflow
|
|
|
|
task_ids = dispatch_workflow(workflow_definition.config, context_id=str(order_line_id))
|
|
|
|
The function:
|
|
1. Validates the raw config dict via WorkflowConfig.
|
|
2. Sorts nodes into dependency order (topological sort, Kahn's algorithm).
|
|
3. Maps each StepName to the corresponding Celery task name string.
|
|
4. Sends each task independently (not chained) so that every node gets its own
|
|
Celery task ID and can be tracked individually.
|
|
|
|
Nodes whose StepName has no mapping in STEP_TASK_MAP are skipped with a
|
|
warning — this lets you add new StepName values to the enum without breaking
|
|
existing dispatchers.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import uuid
|
|
from collections import deque
|
|
from dataclasses import dataclass, field
|
|
from typing import Literal
|
|
|
|
from app.domains.rendering.workflow_schema import WorkflowConfig, WorkflowEdge, WorkflowNode
|
|
from app.core.process_steps import StepName
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Closed set of execution-mode labels. Threaded through WorkflowContext and
# echoed in dispatch log lines; the per-mode semantics are defined by callers
# (presumably "shadow" means dispatch-for-comparison — confirm at call sites).
WorkflowExecutionMode = Literal["legacy", "graph", "shadow"]
|
|
|
|
|
|
@dataclass(slots=True)
class WorkflowContext:
    """Runtime context for one validated workflow execution.

    Built by prepare_workflow_context() and consumed by
    dispatch_prepared_workflow().
    """

    # UUID string of the entity being processed; interpreted per StepName as
    # either a cad_file_id or an order_line_id (see dispatch docstrings).
    context_id: str
    # Mode label included in dispatch log messages.
    execution_mode: WorkflowExecutionMode
    # Identifier of the associated workflow run, when one exists.
    workflow_run_id: uuid.UUID | None = None
    # Nodes already sorted into dependency (topological) order.
    ordered_nodes: list[WorkflowNode] = field(default_factory=list)
    # Edges copied from the validated config, kept for downstream inspection.
    edges: list[WorkflowEdge] = field(default_factory=list)
|
|
|
|
|
|
@dataclass(slots=True)
class WorkflowTaskDispatchSpec:
    """A pre-built Celery submission for a single workflow node.

    Specs are collected on a WorkflowDispatchResult and later sent by
    submit_prepared_workflow_tasks(), which lets callers persist the task IDs
    before anything reaches the broker.
    """

    # ID of the workflow node this task realizes.
    node_id: str
    # Full dotted Celery task name passed to celery_app.send_task().
    task_name: str
    # Positional arguments for the task.
    args: list[str]
    # Keyword arguments for the task.
    kwargs: dict
    # Pre-generated Celery task ID, forwarded via send_task(task_id=...).
    task_id: str
    # Explicit Celery queue; default routing applies when None/empty.
    queue: str | None = None
|
|
|
|
|
|
@dataclass(slots=True)
class WorkflowDispatchResult:
    """Outcome of dispatching (or preparing to dispatch) a workflow."""

    # Runtime context the dispatch was performed under.
    context: WorkflowContext
    # Celery task IDs in dispatch order.
    task_ids: list[str]
    # Mapping of workflow node ID -> Celery task ID.
    node_task_ids: dict[str, str]
    # Node IDs skipped because their StepName has no entry in STEP_TASK_MAP.
    skipped_node_ids: list[str]
    # Deferred submissions, consumed by submit_prepared_workflow_tasks().
    task_specs: list[WorkflowTaskDispatchSpec] = field(default_factory=list)
|
|
|
|
|
|
class WorkflowTaskSubmissionError(RuntimeError):
    """Raised when a prepared workflow task cannot be submitted to Celery.

    Carries the IDs of tasks that were already submitted successfully before
    the failure, so the caller can reconcile a partial dispatch.
    """

    def __init__(self, message: str, *, submitted_task_ids: list[str] | None = None) -> None:
        super().__init__(message)
        # Defensive copy: never alias the caller's list.
        if submitted_task_ids is None:
            self.submitted_task_ids: list[str] = []
        else:
            self.submitted_task_ids = list(submitted_task_ids)
|
|
|
|
|
|
def submit_prepared_workflow_tasks(dispatch_result: WorkflowDispatchResult) -> None:
    """Submit pre-built Celery tasks after DB state has been committed.

    Walks ``dispatch_result.task_specs`` in order and sends each via
    ``celery_app.send_task`` using the spec's pre-generated task ID (and
    queue, when set).

    Raises:
        WorkflowTaskSubmissionError: if any submission fails; the error
            carries the task IDs that were submitted before the failure.
    """
    # Imported lazily to avoid pulling Celery into module import time.
    from app.tasks.celery_app import celery_app

    sent: list[str] = []
    for spec in dispatch_result.task_specs:
        options: dict[str, str] = {"task_id": spec.task_id}
        if spec.queue:
            options["queue"] = spec.queue
        try:
            celery_app.send_task(
                spec.task_name,
                args=spec.args,
                kwargs=spec.kwargs,
                **options,
            )
        except Exception as exc:
            raise WorkflowTaskSubmissionError(
                f"Failed to submit workflow task for node '{spec.node_id}': {exc}",
                submitted_task_ids=sent,
            ) from exc
        sent.append(spec.task_id)
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Step → Celery task name mapping
#
# Values are the full dotted task name strings passed to celery_app.send_task().
# Keep this in sync with the @celery_app.task(name=...) decorators in the
# various tasks modules.
#
# Note: several StepName values intentionally map to the same task
# (process_step_file covers multiple pipeline steps).
# ---------------------------------------------------------------------------
STEP_TASK_MAP: dict[StepName, str] = {
    # ── STEP file processing ─────────────────────────────────────────────
    StepName.RESOLVE_STEP_PATH: "app.tasks.step_tasks.process_step_file",
    StepName.OCC_OBJECT_EXTRACT: "app.tasks.step_tasks.process_step_file",
    StepName.OCC_GLB_EXPORT: "app.tasks.step_tasks.generate_gltf_geometry_task",
    StepName.STL_CACHE_GENERATE: "app.tasks.step_tasks.process_step_file",
    # ── Thumbnail generation ─────────────────────────────────────────────
    StepName.BLENDER_RENDER: "app.tasks.step_tasks.render_step_thumbnail",
    StepName.THUMBNAIL_SAVE: "app.tasks.step_tasks.render_graph_thumbnail",
    # ── Order line stills & turntables ──────────────────────────────────
    StepName.BLENDER_STILL: "app.domains.rendering.tasks.render_order_line_still_task",
    StepName.BLENDER_TURNTABLE: "app.domains.rendering.tasks.render_turntable_task",
    # ── Asset export ─────────────────────────────────────────────────────
    StepName.EXPORT_BLEND: "app.domains.rendering.tasks.export_blend_for_order_line_task",
    # ── Steps without a dedicated standalone task (no mapping) ───────────
    # StepName.GLB_BBOX — computed inline inside process_step_file
    # StepName.MATERIAL_MAP_RESOLVE — computed inline inside render tasks
    # StepName.AUTO_POPULATE_MATERIALS — computed inline inside process_step_file
    # StepName.THREEJS_RENDER — no standalone task exists yet
    # StepName.ORDER_LINE_SETUP — computed inline inside render_order_line_task
    # StepName.RESOLVE_TEMPLATE — computed inline inside render_order_line_task
    # StepName.OUTPUT_SAVE — handled via publish_asset after render tasks
    # StepName.NOTIFY — emitted inline via notification_service
}
|
|
|
|
|
|
def prepare_workflow_context(
    workflow_config: dict,
    *,
    context_id: str,
    execution_mode: WorkflowExecutionMode = "graph",
    workflow_run_id: uuid.UUID | None = None,
) -> WorkflowContext:
    """Validate a raw workflow config dict and build the runtime context.

    Raises:
        pydantic.ValidationError: if the config does not match WorkflowConfig.
        ValueError: if the node graph contains a cycle.
    """
    validated = WorkflowConfig.model_validate(workflow_config)
    return WorkflowContext(
        context_id=context_id,
        execution_mode=execution_mode,
        workflow_run_id=workflow_run_id,
        ordered_nodes=_topological_sort(validated),
        edges=list(validated.edges),
    )
|
|
|
|
|
|
def dispatch_prepared_workflow(context: WorkflowContext) -> WorkflowDispatchResult:
    """Dispatch prepared workflow nodes, in order, as individual Celery tasks.

    Each node whose StepName appears in STEP_TASK_MAP is sent immediately via
    ``celery_app.send_task`` (not chained), receiving the context_id as its
    single positional argument and the node's params as kwargs. Unmapped
    steps are logged and recorded as skipped.
    """
    # Imported lazily to avoid pulling Celery into module import time.
    from app.tasks.celery_app import celery_app

    dispatched: list[str] = []
    by_node: dict[str, str] = {}
    skipped: list[str] = []

    for node in context.ordered_nodes:
        task_name = STEP_TASK_MAP.get(node.step)
        if task_name is None:
            # Unknown steps are tolerated so the enum can grow ahead of this map.
            logger.warning(
                "[WORKFLOW] No Celery task mapping for step %r in mode %s - skipping node %r",
                node.step,
                context.execution_mode,
                node.id,
            )
            skipped.append(node.id)
            continue

        async_result = celery_app.send_task(
            task_name, args=[context.context_id], kwargs=node.params
        )
        dispatched.append(async_result.id)
        by_node[node.id] = async_result.id
        logger.info(
            "[WORKFLOW] Dispatched node %r (step=%s, mode=%s, run=%s) -> Celery task %s",
            node.id,
            node.step,
            context.execution_mode,
            context.workflow_run_id,
            async_result.id,
        )

    logger.info(
        "[WORKFLOW] dispatch_prepared_workflow complete: %d task(s) dispatched for context %s",
        len(dispatched),
        context.context_id,
    )
    return WorkflowDispatchResult(
        context=context,
        task_ids=dispatched,
        node_task_ids=by_node,
        skipped_node_ids=skipped,
    )
|
|
|
|
|
|
def dispatch_workflow_with_context(
    workflow_config: dict,
    *,
    context_id: str,
    execution_mode: WorkflowExecutionMode = "graph",
    workflow_run_id: uuid.UUID | None = None,
) -> WorkflowDispatchResult:
    """Validate a workflow config and dispatch its nodes as Celery tasks.

    Convenience wrapper: prepare_workflow_context() followed by
    dispatch_prepared_workflow().

    Args:
        workflow_config: Raw dict from WorkflowDefinition.config JSONB.
        context_id: UUID string of the entity being processed. Depending on
            the StepName, this is interpreted as a ``cad_file_id`` (for STEP
            processing / thumbnail tasks) or an ``order_line_id`` (for render
            tasks).
        execution_mode: Mode label carried on the context and in log lines.
        workflow_run_id: Optional workflow-run identifier for traceability.

    Returns:
        Dispatch result including runtime context, dispatched task IDs and
        skipped nodes.

    Raises:
        pydantic.ValidationError: if *workflow_config* does not conform to
            WorkflowConfig schema.
        ValueError: if the node graph contains a cycle.
    """
    prepared = prepare_workflow_context(
        workflow_config,
        context_id=context_id,
        execution_mode=execution_mode,
        workflow_run_id=workflow_run_id,
    )
    return dispatch_prepared_workflow(prepared)
|
|
|
|
|
|
def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]:
    """Backward-compatible wrapper returning only dispatched Celery task IDs."""
    result = dispatch_workflow_with_context(workflow_config, context_id=context_id)
    return result.task_ids
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Internal helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _topological_sort(config: WorkflowConfig) -> list[WorkflowNode]:
|
|
"""Return nodes in dependency order using Kahn's algorithm.
|
|
|
|
Nodes with no incoming edges are processed first. If the graph contains a
|
|
cycle a ``ValueError`` is raised (the caller should treat this as a bad
|
|
workflow config).
|
|
"""
|
|
node_map: dict[str, WorkflowNode] = {n.id: n for n in config.nodes}
|
|
in_degree: dict[str, int] = {n.id: 0 for n in config.nodes}
|
|
adjacency: dict[str, list[str]] = {n.id: [] for n in config.nodes}
|
|
|
|
for edge in config.edges:
|
|
adjacency[edge.from_node].append(edge.to_node)
|
|
in_degree[edge.to_node] += 1
|
|
|
|
queue: deque[str] = deque(nid for nid, deg in in_degree.items() if deg == 0)
|
|
result: list[WorkflowNode] = []
|
|
|
|
while queue:
|
|
nid = queue.popleft()
|
|
result.append(node_map[nid])
|
|
for neighbor in adjacency[nid]:
|
|
in_degree[neighbor] -= 1
|
|
if in_degree[neighbor] == 0:
|
|
queue.append(neighbor)
|
|
|
|
if len(result) != len(config.nodes):
|
|
raise ValueError(
|
|
"Workflow config contains a cycle — topological sort failed. "
|
|
f"Processed {len(result)}/{len(config.nodes)} nodes."
|
|
)
|
|
|
|
return result
|