diff --git a/backend/app/domains/rendering/workflow_executor.py b/backend/app/domains/rendering/workflow_executor.py
new file mode 100644
index 0000000..764269f
--- /dev/null
+++ b/backend/app/domains/rendering/workflow_executor.py
@@ -0,0 +1,155 @@
+"""Execute a validated WorkflowDefinition config as individual Celery tasks.
+
+Usage::
+
+    from app.domains.rendering.workflow_executor import dispatch_workflow
+
+    task_ids = dispatch_workflow(workflow_definition.config, context_id=str(order_line_id))
+
+The function:
+1. Validates the raw config dict via WorkflowConfig.
+2. Sorts nodes into dependency order (topological sort, Kahn's algorithm).
+3. Maps each StepName to the corresponding Celery task name string.
+4. Sends each task independently (not chained) so that every node gets its own
+   Celery task ID and can be tracked individually. Topological order therefore
+   only controls enqueue order; nothing makes Celery wait for an upstream node
+   to finish before a downstream node starts.
+
+Nodes whose StepName has no mapping in STEP_TASK_MAP are skipped with a
+warning — this lets you add new StepName values to the enum without breaking
+existing dispatchers.
+"""
+import logging
+from collections import deque
+
+from app.domains.rendering.workflow_schema import WorkflowConfig, WorkflowNode
+from app.core.process_steps import StepName
+
+logger = logging.getLogger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# Step → Celery task name mapping
+#
+# Values are the full dotted task name strings passed to celery_app.send_task().
+# Keep this in sync with the @celery_app.task(name=...) decorators in the
+# various tasks modules.
+# ---------------------------------------------------------------------------
+STEP_TASK_MAP: dict[StepName, str] = {
+    # ── STEP file processing ─────────────────────────────────────────────
+    StepName.RESOLVE_STEP_PATH: "app.tasks.step_tasks.process_step_file",
+    StepName.OCC_OBJECT_EXTRACT: "app.tasks.step_tasks.process_step_file",
+    StepName.OCC_GLB_EXPORT: "app.tasks.step_tasks.generate_gltf_geometry_task",
+    StepName.STL_CACHE_GENERATE: "app.tasks.step_tasks.process_step_file",
+    # ── Thumbnail generation ─────────────────────────────────────────────
+    StepName.BLENDER_RENDER: "app.tasks.step_tasks.render_step_thumbnail",
+    StepName.THUMBNAIL_SAVE: "app.tasks.step_tasks.render_step_thumbnail",
+    # ── Order line stills & turntables ──────────────────────────────────
+    StepName.BLENDER_STILL: "app.domains.rendering.tasks.render_order_line_still_task",
+    StepName.BLENDER_TURNTABLE: "app.domains.rendering.tasks.render_turntable_task",
+    # ── GLB / asset export ───────────────────────────────────────────────
+    StepName.EXPORT_GLB_GEOMETRY: "app.domains.rendering.tasks.export_gltf_for_order_line_task",
+    StepName.EXPORT_BLEND: "app.domains.rendering.tasks.export_blend_for_order_line_task",
+    StepName.EXPORT_GLB_PRODUCTION: "app.tasks.step_tasks.generate_gltf_production_task",
+    # ── Steps without a dedicated standalone task (no mapping) ───────────
+    # StepName.GLB_BBOX — computed inline inside process_step_file
+    # StepName.MATERIAL_MAP_RESOLVE — computed inline inside render tasks
+    # StepName.AUTO_POPULATE_MATERIALS — computed inline inside process_step_file
+    # StepName.THREEJS_RENDER — no standalone task exists yet
+    # StepName.ORDER_LINE_SETUP — computed inline inside render_order_line_task
+    # StepName.RESOLVE_TEMPLATE — computed inline inside render_order_line_task
+    # StepName.OUTPUT_SAVE — handled via publish_asset after render tasks
+    # StepName.NOTIFY — emitted inline via notification_service
+}
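+
+# Illustrative only: because unmapped steps are skipped (not rejected), you can
+# enumerate them when a workflow dispatches fewer tasks than expected, e.g.::
+#
+#     unmapped = [s for s in StepName if s not in STEP_TASK_MAP]
+#     logger.debug("[WORKFLOW] steps without standalone tasks: %s", unmapped)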
+
+
+def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]:
+    """Execute workflow nodes in topological order as individual Celery tasks.
+
+    Args:
+        workflow_config: Raw dict from WorkflowDefinition.config JSONB.
+        context_id: UUID string of the entity being processed. Depending on
+            the StepName, this is interpreted as a ``cad_file_id`` (for STEP
+            processing / thumbnail tasks) or an ``order_line_id`` (for render
+            tasks).
+
+    Returns:
+        List of dispatched Celery task ID strings (one per mapped node).
+
+    Raises:
+        pydantic.ValidationError: if *workflow_config* does not conform to
+            the WorkflowConfig schema.
+        ValueError: if the node graph contains a cycle.
+    """
+    from app.tasks.celery_app import celery_app
+
+    # Validate + parse config
+    config = WorkflowConfig.model_validate(workflow_config)
+
+    # Topological sort so that dependency nodes are enqueued first
+    ordered_nodes = _topological_sort(config)
+
+    task_ids: list[str] = []
+    for node in ordered_nodes:
+        task_name = STEP_TASK_MAP.get(node.step)
+        if task_name is None:
+            logger.warning(
+                "[WORKFLOW] No Celery task mapping for step %r — skipping node %r",
+                node.step,
+                node.id,
+            )
+            continue
+
+        result = celery_app.send_task(task_name, args=[context_id], kwargs=node.params)
+        task_ids.append(result.id)
+        logger.info(
+            "[WORKFLOW] Dispatched node %r (step=%s) → Celery task %s",
+            node.id,
+            node.step,
+            result.id,
+        )
+
+    logger.info(
+        "[WORKFLOW] dispatch_workflow complete: %d task(s) dispatched for context %s",
+        len(task_ids),
+        context_id,
+    )
+    return task_ids
+
+
+# ---------------------------------------------------------------------------
+# Internal helpers
+# ---------------------------------------------------------------------------
+
+def _topological_sort(config: WorkflowConfig) -> list[WorkflowNode]:
+    """Return nodes in dependency order using Kahn's algorithm.
+
+    Nodes with no incoming edges are processed first. If the graph contains a
+    cycle a ``ValueError`` is raised (the caller should treat this as a bad
+    workflow config).
+    """
+    node_map: dict[str, WorkflowNode] = {n.id: n for n in config.nodes}
+    in_degree: dict[str, int] = {n.id: 0 for n in config.nodes}
+    adjacency: dict[str, list[str]] = {n.id: [] for n in config.nodes}
+
+    for edge in config.edges:
+        adjacency[edge.from_node].append(edge.to_node)
+        in_degree[edge.to_node] += 1
+
+    queue: deque[str] = deque(nid for nid, deg in in_degree.items() if deg == 0)
+    result: list[WorkflowNode] = []
+
+    while queue:
+        nid = queue.popleft()
+        result.append(node_map[nid])
+        for neighbor in adjacency[nid]:
+            in_degree[neighbor] -= 1
+            if in_degree[neighbor] == 0:
+                queue.append(neighbor)
+
+    if len(result) != len(config.nodes):
+        raise ValueError(
+            "Workflow config contains a cycle — topological sort failed. "
+            f"Processed {len(result)}/{len(config.nodes)} nodes."
+        )
+
+    return result
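Because ``_topological_sort`` is pure (it only reads the parsed config), it is easy to pin down with a unit test. A sketch, assuming pytest; the step values come from the schema's docstring example below and the node ids are made up::

    import pytest

    from app.domains.rendering.workflow_executor import _topological_sort
    from app.domains.rendering.workflow_schema import WorkflowConfig


    def test_dependency_nodes_come_first():
        config = WorkflowConfig.model_validate({
            "nodes": [
                {"id": "n2", "step": "blender_still"},
                {"id": "n1", "step": "resolve_step_path"},
            ],
            "edges": [{"from": "n1", "to": "n2"}],
        })
        assert [n.id for n in _topological_sort(config)] == ["n1", "n2"]


    def test_cycle_is_rejected():
        config = WorkflowConfig.model_validate({
            "nodes": [
                {"id": "a", "step": "resolve_step_path"},
                {"id": "b", "step": "blender_still"},
            ],
            "edges": [{"from": "a", "to": "b"}, {"from": "b", "to": "a"}],
        })
        with pytest.raises(ValueError, match="cycle"):
            _topological_sort(config)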
diff --git a/backend/app/domains/rendering/workflow_router.py b/backend/app/domains/rendering/workflow_router.py
index e9b251f..275e6ba 100644
--- a/backend/app/domains/rendering/workflow_router.py
+++ b/backend/app/domains/rendering/workflow_router.py
@@ -12,15 +12,15 @@ Endpoints:
 import uuid
 from typing import Literal
 
-from fastapi import APIRouter, Depends, HTTPException
-from pydantic import BaseModel
+from fastapi import APIRouter, Depends, HTTPException, Query
+from pydantic import BaseModel, ValidationError
 from sqlalchemy import select
 from sqlalchemy.orm import selectinload
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from app.database import get_db
 from app.domains.auth.models import User
-from app.utils.auth import get_current_user, require_admin, require_admin_or_pm
+from app.utils.auth import get_current_user, require_admin, require_admin_or_pm, require_pm_or_above
 from app.domains.rendering.models import WorkflowDefinition, WorkflowRun
 from app.domains.rendering.schemas import (
     WorkflowDefinitionCreate,
@@ -28,6 +28,7 @@ from app.domains.rendering.schemas import (
     WorkflowDefinitionOut,
     WorkflowRunOut,
 )
+from app.domains.rendering.workflow_schema import WorkflowConfig
 from app.core.process_steps import StepName
 
 
@@ -142,6 +143,11 @@ async def create_workflow(
     _user: User = Depends(require_admin),
     db: AsyncSession = Depends(get_db),
 ):
+    if body.config:
+        try:
+            WorkflowConfig.model_validate(body.config)
+        except ValidationError as exc:
+            raise HTTPException(status_code=422, detail=f"Invalid workflow config: {exc.errors()}") from exc
     wf = WorkflowDefinition(
         name=body.name,
         output_type_id=body.output_type_id,
@@ -171,6 +177,10 @@ async def update_workflow(
     if body.name is not None:
         wf.name = body.name
     if body.config is not None:
+        try:
+            WorkflowConfig.model_validate(body.config)
+        except ValidationError as exc:
+            raise HTTPException(status_code=422, detail=f"Invalid workflow config: {exc.errors()}") from exc
         wf.config = body.config
     if body.is_active is not None:
         wf.is_active = body.is_active
@@ -216,3 +226,50 @@ async def list_workflow_runs(
         .order_by(WorkflowRun.created_at.desc())
     )
     return result.scalars().all()
+
+
+class WorkflowDispatchResponse(BaseModel):
+    dispatched: int
+    task_ids: list[str]
+
+
+@router.post("/{workflow_id}/dispatch", response_model=WorkflowDispatchResponse)
+async def dispatch_workflow_endpoint(
+    workflow_id: uuid.UUID,
+    context_id: str = Query(
+        ...,
+        description=(
+            "UUID of the entity to process. "
+            "For STEP/thumbnail steps this is a cad_file_id; "
+            "for render steps this is an order_line_id."
+        ),
+    ),
+    _user: User = Depends(require_pm_or_above),
+    db: AsyncSession = Depends(get_db),
+):
+    """Dispatch a workflow's steps as Celery tasks for a given context entity.
+
+    Each node in the workflow config is dispatched as an individual Celery task
+    in topological (dependency) order. Returns the list of Celery task IDs so
+    the caller can track progress.
+    """
+    from app.domains.rendering.workflow_executor import dispatch_workflow
+
+    result = await db.execute(
+        select(WorkflowDefinition).where(WorkflowDefinition.id == workflow_id)
+    )
+    wf = result.scalar_one_or_none()
+    if not wf:
+        raise HTTPException(status_code=404, detail="Workflow definition not found")
+    if not wf.config:
+        raise HTTPException(status_code=400, detail="Workflow has no config")
+
+    try:
+        task_ids = dispatch_workflow(wf.config, context_id)
+    except ValidationError as exc:
+        raise HTTPException(status_code=422, detail=f"Invalid workflow config: {exc.errors()}") from exc
+    except ValueError as exc:
+        raise HTTPException(status_code=422, detail=str(exc)) from exc
+
+    return WorkflowDispatchResponse(dispatched=len(task_ids), task_ids=task_ids)
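For reference, calling the new endpoint from a client might look like the sketch below. The ``/workflows`` prefix, base URL, bearer-token auth, and all IDs are assumptions; the router's mount point and auth transport are not shown in this diff::

    import httpx

    workflow_id = "00000000-0000-0000-0000-000000000000"    # placeholder
    order_line_id = "00000000-0000-0000-0000-000000000001"  # placeholder
    token = "a-jwt-for-a-pm-or-admin-user"                  # placeholder

    # context_id is a query parameter, matching the Query(...) declaration above.
    resp = httpx.post(
        f"https://api.example.com/workflows/{workflow_id}/dispatch",
        params={"context_id": order_line_id},
        headers={"Authorization": f"Bearer {token}"},
    )
    resp.raise_for_status()
    data = resp.json()  # e.g. {"dispatched": 3, "task_ids": ["..."]}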
+ """ + from pydantic import ValidationError as _ValidationError + from app.domains.rendering.workflow_executor import dispatch_workflow + + result = await db.execute( + select(WorkflowDefinition).where(WorkflowDefinition.id == workflow_id) + ) + wf = result.scalar_one_or_none() + if not wf: + raise HTTPException(status_code=404, detail="Workflow definition not found") + if not wf.config: + raise HTTPException(status_code=400, detail="Workflow has no config") + + try: + task_ids = dispatch_workflow(wf.config, context_id) + except _ValidationError as exc: + raise HTTPException(status_code=422, detail=f"Invalid workflow config: {exc.errors()}") + except ValueError as exc: + raise HTTPException(status_code=422, detail=str(exc)) + + return WorkflowDispatchResponse(dispatched=len(task_ids), task_ids=task_ids) diff --git a/backend/app/domains/rendering/workflow_schema.py b/backend/app/domains/rendering/workflow_schema.py new file mode 100644 index 0000000..5cc20b2 --- /dev/null +++ b/backend/app/domains/rendering/workflow_schema.py @@ -0,0 +1,75 @@ +"""Pydantic schema for validated WorkflowDefinition.config JSONB. + +A workflow config is a versioned DAG description stored as JSONB. Before +being dispatched (or saved), the raw dict must pass this schema. + +Example config:: + + { + "version": 1, + "nodes": [ + {"id": "n1", "step": "resolve_step_path", "params": {}}, + {"id": "n2", "step": "blender_still", "params": {"engine": "cycles"}} + ], + "edges": [ + {"from": "n1", "to": "n2"} + ] + } +""" +from pydantic import BaseModel, Field, field_validator, model_validator + +from app.core.process_steps import StepName + + +class WorkflowNode(BaseModel): + id: str + step: StepName # validated against the StepName StrEnum + params: dict = {} + + +class WorkflowEdge(BaseModel): + # "from" is a Python keyword, so we alias it. + from_node: str = Field(alias="from") + to_node: str = Field(alias="to") + + model_config = {"populate_by_name": True} + + +class WorkflowConfig(BaseModel): + version: int = 1 + nodes: list[WorkflowNode] + edges: list[WorkflowEdge] = [] + + @field_validator("nodes") + @classmethod + def nodes_not_empty(cls, v: list[WorkflowNode]) -> list[WorkflowNode]: + if not v: + raise ValueError("workflow must have at least one node") + return v + + @field_validator("edges") + @classmethod + def edges_reference_valid_nodes( + cls, edges: list[WorkflowEdge], info + ) -> list[WorkflowEdge]: + if "nodes" in info.data: + node_ids = {n.id for n in info.data["nodes"]} + for edge in edges: + if edge.from_node not in node_ids: + raise ValueError( + f"edge references unknown node id: {edge.from_node!r}" + ) + if edge.to_node not in node_ids: + raise ValueError( + f"edge references unknown node id: {edge.to_node!r}" + ) + return edges + + @model_validator(mode="after") + def node_ids_are_unique(self) -> "WorkflowConfig": + seen: set[str] = set() + for node in self.nodes: + if node.id in seen: + raise ValueError(f"duplicate node id: {node.id!r}") + seen.add(node.id) + return self