diff --git a/backend/app/domains/rendering/dispatch_service.py b/backend/app/domains/rendering/dispatch_service.py index ae7a288..ef630f5 100644 --- a/backend/app/domains/rendering/dispatch_service.py +++ b/backend/app/domains/rendering/dispatch_service.py @@ -13,7 +13,6 @@ task-dispatch logic in app.services.render_dispatcher (legacy path). from __future__ import annotations import logging -from datetime import datetime logger = logging.getLogger(__name__) @@ -32,7 +31,9 @@ def dispatch_render_with_workflow(order_line_id: str) -> dict: from app.config import settings from app.domains.orders.models import OrderLine - from app.domains.rendering.models import OutputType, WorkflowDefinition, WorkflowRun + from app.domains.rendering.models import OutputType, WorkflowDefinition + from app.domains.rendering.workflow_executor import prepare_workflow_context + from app.domains.rendering.workflow_run_service import create_workflow_run, mark_workflow_run_failed engine = create_engine( settings.database_url.replace("+asyncpg", ""), @@ -96,6 +97,22 @@ def dispatch_render_with_workflow(order_line_id: str) -> dict: workflow_type, ) + try: + workflow_context = prepare_workflow_context( + wf_def.config, + context_id=order_line_id, + execution_mode="legacy", + ) + except Exception as exc: + logger.warning( + "order_line %s: workflow_definition_id %s failed runtime preparation (%s), " + "falling back to legacy dispatch", + order_line_id, + wf_def.id, + exc, + ) + return _legacy_dispatch(order_line_id) + # For turntable workflows: resolve step_path + output_dir from the order line at runtime if workflow_type == "turntable" and ("step_path" not in params or "output_dir" not in params): from app.domains.products.models import CadFile as _CadFile @@ -120,19 +137,43 @@ def dispatch_render_with_workflow(order_line_id: str) -> dict: str(_Path(_cfg.upload_dir) / "renders" / str(line.id)), ) - from app.domains.rendering.workflow_builder import dispatch_workflow - celery_task_id = dispatch_workflow(workflow_type, order_line_id, params) + run = None + try: + run = create_workflow_run( + session, + workflow_def_id=wf_def.id, + order_line_id=line.id, + workflow_context=workflow_context, + ) + session.commit() + except Exception as exc: + session.rollback() + logger.warning( + "order_line %s: failed to create workflow run for workflow_definition_id %s (%s), " + "falling back to legacy dispatch", + order_line_id, + wf_def.id, + exc, + ) + return _legacy_dispatch(order_line_id) - # Persist a WorkflowRun record - run = WorkflowRun( - workflow_def_id=wf_def.id, - order_line_id=line.id, - celery_task_id=celery_task_id, - status="pending", - started_at=datetime.utcnow(), - ) - session.add(run) - session.commit() + from app.domains.rendering.workflow_builder import dispatch_workflow + + try: + celery_task_id = dispatch_workflow(workflow_type, order_line_id, params) + run.celery_task_id = celery_task_id + session.commit() + except Exception as exc: + session.rollback() + session.add(run) + mark_workflow_run_failed(run, str(exc)) + session.commit() + logger.exception( + "order_line %s: workflow dispatch via definition %s failed, falling back to legacy dispatch", + order_line_id, + wf_def.id, + ) + return _legacy_dispatch(order_line_id) return { "backend": "workflow", diff --git a/backend/app/domains/rendering/workflow_executor.py b/backend/app/domains/rendering/workflow_executor.py index f2f630e..9a9b137 100644 --- a/backend/app/domains/rendering/workflow_executor.py +++ b/backend/app/domains/rendering/workflow_executor.py @@ -17,8 +17,13 @@ Nodes whose StepName has no mapping in STEP_TASK_MAP are skipped with a warning — this lets you add new StepName values to the enum without breaking existing dispatchers. """ +from __future__ import annotations + import logging +import uuid from collections import deque +from dataclasses import dataclass, field +from typing import Literal from app.domains.rendering.workflow_schema import WorkflowConfig, WorkflowNode from app.core.process_steps import StepName @@ -26,6 +31,25 @@ from app.core.process_steps import StepName logger = logging.getLogger(__name__) +WorkflowExecutionMode = Literal["legacy", "graph", "shadow"] + + +@dataclass(slots=True) +class WorkflowContext: + context_id: str + execution_mode: WorkflowExecutionMode + workflow_run_id: uuid.UUID | None = None + ordered_nodes: list[WorkflowNode] = field(default_factory=list) + + +@dataclass(slots=True) +class WorkflowDispatchResult: + context: WorkflowContext + task_ids: list[str] + node_task_ids: dict[str, str] + skipped_node_ids: list[str] + + # --------------------------------------------------------------------------- # Step → Celery task name mapping # @@ -59,7 +83,76 @@ STEP_TASK_MAP: dict[StepName, str] = { } -def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]: +def prepare_workflow_context( + workflow_config: dict, + *, + context_id: str, + execution_mode: WorkflowExecutionMode = "graph", + workflow_run_id: uuid.UUID | None = None, +) -> WorkflowContext: + """Validate workflow config and build the runtime context.""" + config = WorkflowConfig.model_validate(workflow_config) + ordered_nodes = _topological_sort(config) + return WorkflowContext( + context_id=context_id, + execution_mode=execution_mode, + workflow_run_id=workflow_run_id, + ordered_nodes=ordered_nodes, + ) + + +def dispatch_prepared_workflow(context: WorkflowContext) -> WorkflowDispatchResult: + """Execute prepared workflow nodes in topological order as individual Celery tasks.""" + from app.tasks.celery_app import celery_app + + task_ids: list[str] = [] + node_task_ids: dict[str, str] = {} + skipped_node_ids: list[str] = [] + + for node in context.ordered_nodes: + task_name = STEP_TASK_MAP.get(node.step) + if task_name is None: + logger.warning( + "[WORKFLOW] No Celery task mapping for step %r in mode %s - skipping node %r", + node.step, + context.execution_mode, + node.id, + ) + skipped_node_ids.append(node.id) + continue + + result = celery_app.send_task(task_name, args=[context.context_id], kwargs=node.params) + task_ids.append(result.id) + node_task_ids[node.id] = result.id + logger.info( + "[WORKFLOW] Dispatched node %r (step=%s, mode=%s, run=%s) -> Celery task %s", + node.id, + node.step, + context.execution_mode, + context.workflow_run_id, + result.id, + ) + + logger.info( + "[WORKFLOW] dispatch_prepared_workflow complete: %d task(s) dispatched for context %s", + len(task_ids), + context.context_id, + ) + return WorkflowDispatchResult( + context=context, + task_ids=task_ids, + node_task_ids=node_task_ids, + skipped_node_ids=skipped_node_ids, + ) + + +def dispatch_workflow_with_context( + workflow_config: dict, + *, + context_id: str, + execution_mode: WorkflowExecutionMode = "graph", + workflow_run_id: uuid.UUID | None = None, +) -> WorkflowDispatchResult: """Execute workflow nodes in topological order as individual Celery tasks. Args: @@ -70,47 +163,25 @@ def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]: tasks). Returns: - List of dispatched Celery task ID strings (one per mapped node). + Dispatch result including runtime context, dispatched task IDs and skipped nodes. Raises: pydantic.ValidationError: if *workflow_config* does not conform to WorkflowConfig schema. ValueError: if the node graph contains a cycle. """ - from app.tasks.celery_app import celery_app - - # Validate + parse config - config = WorkflowConfig.model_validate(workflow_config) - - # Topological sort so that dependency nodes dispatch first - ordered_nodes = _topological_sort(config) - - task_ids: list[str] = [] - for node in ordered_nodes: - task_name = STEP_TASK_MAP.get(node.step) - if task_name is None: - logger.warning( - "[WORKFLOW] No Celery task mapping for step %r — skipping node %r", - node.step, - node.id, - ) - continue - - result = celery_app.send_task(task_name, args=[context_id], kwargs=node.params) - task_ids.append(result.id) - logger.info( - "[WORKFLOW] Dispatched node %r (step=%s) → Celery task %s", - node.id, - node.step, - result.id, - ) - - logger.info( - "[WORKFLOW] dispatch_workflow complete: %d task(s) dispatched for context %s", - len(task_ids), - context_id, + context = prepare_workflow_context( + workflow_config, + context_id=context_id, + execution_mode=execution_mode, + workflow_run_id=workflow_run_id, ) - return task_ids + return dispatch_prepared_workflow(context) + + +def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]: + """Backward-compatible wrapper returning only dispatched Celery task IDs.""" + return dispatch_workflow_with_context(workflow_config, context_id=context_id).task_ids # --------------------------------------------------------------------------- diff --git a/backend/app/domains/rendering/workflow_router.py b/backend/app/domains/rendering/workflow_router.py index c315db4..4dd1948 100644 --- a/backend/app/domains/rendering/workflow_router.py +++ b/backend/app/domains/rendering/workflow_router.py @@ -200,6 +200,9 @@ async def list_workflow_runs( class WorkflowDispatchResponse(BaseModel): + workflow_run: WorkflowRunOut + context_id: str + execution_mode: str dispatched: int task_ids: list[str] @@ -225,7 +228,15 @@ async def dispatch_workflow_endpoint( the caller can track progress. """ from pydantic import ValidationError as _ValidationError - from app.domains.rendering.workflow_executor import dispatch_workflow + from app.domains.rendering.workflow_executor import ( + dispatch_prepared_workflow, + prepare_workflow_context, + ) + from app.domains.rendering.workflow_run_service import ( + apply_graph_dispatch_result, + create_workflow_run, + mark_workflow_run_failed, + ) result = await db.execute( select(WorkflowDefinition).where(WorkflowDefinition.id == workflow_id) @@ -237,10 +248,59 @@ async def dispatch_workflow_endpoint( raise HTTPException(status_code=400, detail="Workflow has no config") try: - task_ids = dispatch_workflow(wf.config, context_id) + workflow_context = prepare_workflow_context( + wf.config, + context_id=context_id, + execution_mode="graph", + ) except _ValidationError as exc: raise HTTPException(status_code=422, detail=f"Invalid workflow config: {exc.errors()}") except ValueError as exc: raise HTTPException(status_code=422, detail=str(exc)) - return WorkflowDispatchResponse(dispatched=len(task_ids), task_ids=task_ids) + run_id = await db.run_sync( + lambda sync_session: create_workflow_run( + sync_session, + workflow_def_id=wf.id, + order_line_id=None, + workflow_context=workflow_context, + ).id + ) + await db.commit() + + try: + dispatch_result = dispatch_prepared_workflow(workflow_context) + except Exception as exc: + failed_result = await db.execute( + select(WorkflowRun) + .where(WorkflowRun.id == run_id) + .options(selectinload(WorkflowRun.node_results)) + ) + failed_run = failed_result.scalar_one() + mark_workflow_run_failed(failed_run, str(exc)) + await db.commit() + raise + + run_result = await db.execute( + select(WorkflowRun) + .where(WorkflowRun.id == run_id) + .options(selectinload(WorkflowRun.node_results)) + ) + run = run_result.scalar_one() + apply_graph_dispatch_result(run, workflow_context, dispatch_result) + await db.commit() + + refreshed_result = await db.execute( + select(WorkflowRun) + .where(WorkflowRun.id == run_id) + .options(selectinload(WorkflowRun.node_results)) + ) + refreshed_run = refreshed_result.scalar_one() + + return WorkflowDispatchResponse( + workflow_run=refreshed_run, + context_id=context_id, + execution_mode=workflow_context.execution_mode, + dispatched=len(dispatch_result.task_ids), + task_ids=dispatch_result.task_ids, + ) diff --git a/backend/app/domains/rendering/workflow_run_service.py b/backend/app/domains/rendering/workflow_run_service.py new file mode 100644 index 0000000..d44e97f --- /dev/null +++ b/backend/app/domains/rendering/workflow_run_service.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from datetime import datetime + +from sqlalchemy.orm import Session + +from app.domains.rendering.models import WorkflowNodeResult, WorkflowRun +from app.domains.rendering.workflow_executor import WorkflowContext, WorkflowDispatchResult + + +def create_workflow_run( + session: Session, + *, + workflow_def_id, + order_line_id, + workflow_context: WorkflowContext, +) -> WorkflowRun: + run = WorkflowRun( + workflow_def_id=workflow_def_id, + order_line_id=order_line_id, + status="pending", + started_at=datetime.utcnow(), + ) + session.add(run) + session.flush() + + workflow_context.workflow_run_id = run.id + for node in workflow_context.ordered_nodes: + metadata = {"step": node.step.value} + if node.ui and node.ui.label: + metadata["label"] = node.ui.label + session.add( + WorkflowNodeResult( + run_id=run.id, + node_name=node.id, + status="pending", + output=metadata, + ) + ) + + session.flush() + return run + + +def apply_graph_dispatch_result( + run: WorkflowRun, + workflow_context: WorkflowContext, + dispatch_result: WorkflowDispatchResult, +) -> None: + node_map = {node.id: node for node in workflow_context.ordered_nodes} + results_by_name = {node_result.node_name: node_result for node_result in run.node_results} + + run.celery_task_id = dispatch_result.task_ids[0] if dispatch_result.task_ids else None + + for node_id, node_result in results_by_name.items(): + node = node_map.get(node_id) + if node is None: + continue + + metadata = dict(node_result.output or {}) + metadata.setdefault("step", node.step.value) + if node.ui and node.ui.label: + metadata.setdefault("label", node.ui.label) + + task_id = dispatch_result.node_task_ids.get(node_id) + if task_id is not None: + node_result.status = "queued" + metadata["task_id"] = task_id + node_result.output = metadata + continue + + if node_id in dispatch_result.skipped_node_ids: + node_result.status = "skipped" + node_result.output = metadata + node_result.log = f"No Celery task mapping for step '{node.step.value}'" + + +def mark_workflow_run_failed(run: WorkflowRun, error_message: str) -> None: + run.status = "failed" + run.completed_at = datetime.utcnow() + run.error_message = error_message[:2000] diff --git a/backend/tests/domains/test_workflow_dispatch_service.py b/backend/tests/domains/test_workflow_dispatch_service.py new file mode 100644 index 0000000..c1fee7c --- /dev/null +++ b/backend/tests/domains/test_workflow_dispatch_service.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +import uuid + +import pytest +from sqlalchemy import select +from sqlalchemy.orm import selectinload + +from app.config import settings +from app.domains.orders.models import Order, OrderLine +from app.domains.products.models import Product +from app.domains.rendering.dispatch_service import dispatch_render_with_workflow +from app.domains.rendering.models import OutputType, WorkflowDefinition, WorkflowRun +from app.domains.rendering.workflow_config_utils import build_preset_workflow_config + + +def _use_test_database(monkeypatch) -> None: + monkeypatch.setattr(settings, "postgres_host", "postgres") + monkeypatch.setattr(settings, "postgres_port", 5432) + monkeypatch.setattr(settings, "postgres_user", "hartomat") + monkeypatch.setattr(settings, "postgres_password", "hartomat") + monkeypatch.setattr(settings, "postgres_db", "hartomat_test") + + +async def _seed_order_line( + db, + admin_user, + *, + workflow_config: dict | None = None, +) -> dict[str, object]: + product = Product( + pim_id=f"PIM-{uuid.uuid4().hex[:8]}", + name="Workflow Test Product", + ) + output_type = OutputType( + name=f"Workflow Output {uuid.uuid4().hex[:8]}", + render_backend="auto", + ) + order = Order( + order_number=f"WF-{uuid.uuid4().hex[:10]}", + created_by=admin_user.id, + ) + db.add_all([product, output_type, order]) + await db.flush() + + workflow_definition = None + if workflow_config is not None: + workflow_definition = WorkflowDefinition( + name=f"Workflow {uuid.uuid4().hex[:8]}", + output_type_id=output_type.id, + config=workflow_config, + is_active=True, + ) + db.add(workflow_definition) + await db.flush() + output_type.workflow_definition_id = workflow_definition.id + + order_line = OrderLine( + order_id=order.id, + product_id=product.id, + output_type_id=output_type.id, + ) + db.add(order_line) + await db.commit() + + return { + "order_line": order_line, + "workflow_definition": workflow_definition, + "output_type": output_type, + } + + +@pytest.mark.asyncio +async def test_dispatch_render_with_workflow_falls_back_to_legacy_without_workflow_definition( + db, + admin_user, + monkeypatch, +): + _use_test_database(monkeypatch) + seeded = await _seed_order_line(db, admin_user) + + monkeypatch.setattr( + "app.domains.rendering.dispatch_service._legacy_dispatch", + lambda order_line_id: {"backend": "legacy", "order_line_id": order_line_id}, + ) + + result = dispatch_render_with_workflow(str(seeded["order_line"].id)) + + await db.rollback() + + assert result == { + "backend": "legacy", + "order_line_id": str(seeded["order_line"].id), + } + runs = (await db.execute(select(WorkflowRun))).scalars().all() + assert runs == [] + + +@pytest.mark.asyncio +async def test_dispatch_render_with_workflow_creates_run_and_node_results_for_preset_dispatch( + db, + admin_user, + monkeypatch, +): + _use_test_database(monkeypatch) + seeded = await _seed_order_line( + db, + admin_user, + workflow_config=build_preset_workflow_config("still", {"width": 1024, "height": 1024}), + ) + + monkeypatch.setattr( + "app.domains.rendering.workflow_builder.dispatch_workflow", + lambda workflow_type, order_line_id, params=None: "canvas-123", + ) + + result = dispatch_render_with_workflow(str(seeded["order_line"].id)) + + await db.rollback() + + run_result = await db.execute( + select(WorkflowRun) + .where(WorkflowRun.id == uuid.UUID(result["workflow_run_id"])) + .options(selectinload(WorkflowRun.node_results)) + ) + run = run_result.scalar_one() + + assert result["backend"] == "workflow" + assert result["workflow_type"] == "still" + assert result["celery_task_id"] == "canvas-123" + assert run.workflow_def_id == seeded["workflow_definition"].id + assert run.order_line_id == seeded["order_line"].id + assert run.celery_task_id == "canvas-123" + assert {node_result.node_name for node_result in run.node_results} == { + "setup", + "template", + "render", + "output", + } + assert all(node_result.status == "pending" for node_result in run.node_results) + + +@pytest.mark.asyncio +async def test_dispatch_render_with_workflow_falls_back_when_workflow_runtime_preparation_is_invalid( + db, + admin_user, + monkeypatch, +): + _use_test_database(monkeypatch) + seeded = await _seed_order_line( + db, + admin_user, + workflow_config={ + "version": 1, + "nodes": [ + {"id": "render", "step": "blender_still", "params": {}}, + ], + "edges": [ + {"from": "missing", "to": "render"}, + ], + }, + ) + + monkeypatch.setattr( + "app.domains.rendering.dispatch_service._legacy_dispatch", + lambda order_line_id: {"backend": "legacy", "order_line_id": order_line_id}, + ) + + result = dispatch_render_with_workflow(str(seeded["order_line"].id)) + + await db.rollback() + + assert result == { + "backend": "legacy", + "order_line_id": str(seeded["order_line"].id), + } + runs = (await db.execute(select(WorkflowRun))).scalars().all() + assert runs == [] + + +@pytest.mark.asyncio +async def test_workflow_dispatch_endpoint_returns_workflow_run_with_node_results( + client, + db, + auth_headers, + monkeypatch, +): + workflow_definition = WorkflowDefinition( + name=f"Dispatch Workflow {uuid.uuid4().hex[:8]}", + config=build_preset_workflow_config("still_with_exports", {"width": 640, "height": 640}), + is_active=True, + ) + db.add(workflow_definition) + await db.commit() + await db.refresh(workflow_definition) + + calls: list[tuple[str, list[str], dict]] = [] + + def _fake_send_task(task_name: str, args: list[str], kwargs: dict): + calls.append((task_name, args, kwargs)) + return type("Result", (), {"id": f"task-{len(calls)}"})() + + context_id = str(uuid.uuid4()) + monkeypatch.setattr("app.tasks.celery_app.celery_app.send_task", _fake_send_task) + response = await client.post( + f"/api/workflows/{workflow_definition.id}/dispatch", + params={"context_id": context_id}, + headers=auth_headers, + ) + + assert response.status_code == 200 + body = response.json() + + assert body["context_id"] == context_id + assert body["execution_mode"] == "graph" + assert body["dispatched"] == 2 + assert body["task_ids"] == ["task-1", "task-2"] + assert calls == [ + ( + "app.domains.rendering.tasks.render_order_line_still_task", + [context_id], + {"width": 640, "height": 640}, + ), + ( + "app.domains.rendering.tasks.export_blend_for_order_line_task", + [context_id], + {}, + ), + ] + + node_results = {node["node_name"]: node for node in body["workflow_run"]["node_results"]} + assert body["workflow_run"]["status"] == "pending" + assert body["workflow_run"]["celery_task_id"] == "task-1" + assert node_results["render"]["status"] == "queued" + assert node_results["render"]["output"]["task_id"] == "task-1" + assert node_results["blend"]["status"] == "queued" + assert node_results["blend"]["output"]["task_id"] == "task-2" + assert node_results["setup"]["status"] == "skipped" + assert node_results["template"]["status"] == "skipped" + assert node_results["output"]["status"] == "skipped" diff --git a/docs/workflows/WORKFLOW_DELIVERY_CHECKLIST.md b/docs/workflows/WORKFLOW_DELIVERY_CHECKLIST.md index 06ae9a4..a21e4b2 100644 --- a/docs/workflows/WORKFLOW_DELIVERY_CHECKLIST.md +++ b/docs/workflows/WORKFLOW_DELIVERY_CHECKLIST.md @@ -26,10 +26,11 @@ ### Phase 4 -- [ ] Workflow context introduced +- [x] Workflow context introduced - [ ] Node outputs are persisted and reusable - [ ] Graph runtime supports legacy fallback - [ ] `legacy`, `graph`, and `shadow` modes exist +- Progress: Phase 4 foundation now persists `WorkflowRun` and initial `WorkflowNodeResult` records for both linked workflow dispatch and `/api/workflows/{id}/dispatch`, while keeping the legacy preset dispatcher as the safe default fallback. ### Phase 5 diff --git a/docs/workflows/WORKFLOW_IMPLEMENTATION_BACKLOG.md b/docs/workflows/WORKFLOW_IMPLEMENTATION_BACKLOG.md index 86692ab..f548fab 100644 --- a/docs/workflows/WORKFLOW_IMPLEMENTATION_BACKLOG.md +++ b/docs/workflows/WORKFLOW_IMPLEMENTATION_BACKLOG.md @@ -70,7 +70,7 @@ ### Tickets -- `E4-T1` Introduce `WorkflowContext`. +- `E4-T1` Introduce `WorkflowContext`. `completed` - `E4-T2` Refactor executor to process nodes against context and node outputs. - `E4-T3` Persist node-level run records, logs, timings, and outputs. - `E4-T4` Support retry and failure policies.