feat: add workflow run dispatch foundation

2026-04-07 10:11:46 +02:00
parent ab1b220e79
commit 6ad34ceed2
7 changed files with 548 additions and 54 deletions
@@ -13,7 +13,6 @@ task-dispatch logic in app.services.render_dispatcher (legacy path).
 from __future__ import annotations

 import logging
-from datetime import datetime


 logger = logging.getLogger(__name__)
@@ -32,7 +31,9 @@ def dispatch_render_with_workflow(order_line_id: str) -> dict:
     from app.config import settings
     from app.domains.orders.models import OrderLine
-    from app.domains.rendering.models import OutputType, WorkflowDefinition, WorkflowRun
+    from app.domains.rendering.models import OutputType, WorkflowDefinition
+    from app.domains.rendering.workflow_executor import prepare_workflow_context
+    from app.domains.rendering.workflow_run_service import create_workflow_run, mark_workflow_run_failed

     engine = create_engine(
         settings.database_url.replace("+asyncpg", ""),
@@ -96,6 +97,22 @@ def dispatch_render_with_workflow(order_line_id: str) -> dict:
         workflow_type,
     )

+    try:
+        workflow_context = prepare_workflow_context(
+            wf_def.config,
+            context_id=order_line_id,
+            execution_mode="legacy",
+        )
+    except Exception as exc:
+        logger.warning(
+            "order_line %s: workflow_definition_id %s failed runtime preparation (%s), "
+            "falling back to legacy dispatch",
+            order_line_id,
+            wf_def.id,
+            exc,
+        )
+        return _legacy_dispatch(order_line_id)
+
     # For turntable workflows: resolve step_path + output_dir from the order line at runtime
     if workflow_type == "turntable" and ("step_path" not in params or "output_dir" not in params):
         from app.domains.products.models import CadFile as _CadFile
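For reference, the config shape that `prepare_workflow_context` validates here mirrors the fixtures in the regression tests added later in this commit; a minimal sketch (node ids, steps, and params are illustrative):

```python
# Shape accepted by prepare_workflow_context (mirrors the test fixtures below).
valid_config = {
    "version": 1,
    "nodes": [{"id": "render", "step": "blender_still", "params": {"width": 1024, "height": 1024}}],
    "edges": [],
}

# An edge referencing a node id that does not exist makes preparation raise,
# which is exactly what triggers the legacy fallback above.
broken_config = {
    "version": 1,
    "nodes": [{"id": "render", "step": "blender_still", "params": {}}],
    "edges": [{"from": "missing", "to": "render"}],  # "missing" is not a node id
}
```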
@@ -120,19 +137,43 @@ def dispatch_render_with_workflow(order_line_id: str) -> dict:
             str(_Path(_cfg.upload_dir) / "renders" / str(line.id)),
         )

-    from app.domains.rendering.workflow_builder import dispatch_workflow
-    celery_task_id = dispatch_workflow(workflow_type, order_line_id, params)
-
-    # Persist a WorkflowRun record
-    run = WorkflowRun(
-        workflow_def_id=wf_def.id,
-        order_line_id=line.id,
-        celery_task_id=celery_task_id,
-        status="pending",
-        started_at=datetime.utcnow(),
-    )
-    session.add(run)
-    session.commit()
+    run = None
+    try:
+        run = create_workflow_run(
+            session,
+            workflow_def_id=wf_def.id,
+            order_line_id=line.id,
+            workflow_context=workflow_context,
+        )
+        session.commit()
+    except Exception as exc:
+        session.rollback()
+        logger.warning(
+            "order_line %s: failed to create workflow run for workflow_definition_id %s (%s), "
+            "falling back to legacy dispatch",
+            order_line_id,
+            wf_def.id,
+            exc,
+        )
+        return _legacy_dispatch(order_line_id)
+
+    from app.domains.rendering.workflow_builder import dispatch_workflow
+
+    try:
+        celery_task_id = dispatch_workflow(workflow_type, order_line_id, params)
+        run.celery_task_id = celery_task_id
+        session.commit()
+    except Exception as exc:
+        session.rollback()
+        session.add(run)
+        mark_workflow_run_failed(run, str(exc))
+        session.commit()
+        logger.exception(
+            "order_line %s: workflow dispatch via definition %s failed, falling back to legacy dispatch",
+            order_line_id,
+            wf_def.id,
+        )
+        return _legacy_dispatch(order_line_id)

     return {
         "backend": "workflow",
@@ -17,8 +17,13 @@ Nodes whose StepName has no mapping in STEP_TASK_MAP are skipped with a
 warning — this lets you add new StepName values to the enum without breaking
 existing dispatchers.
 """
+from __future__ import annotations
+
 import logging
+import uuid
 from collections import deque
+from dataclasses import dataclass, field
+from typing import Literal

 from app.domains.rendering.workflow_schema import WorkflowConfig, WorkflowNode
 from app.core.process_steps import StepName
@@ -26,6 +31,25 @@ from app.core.process_steps import StepName

 logger = logging.getLogger(__name__)

+WorkflowExecutionMode = Literal["legacy", "graph", "shadow"]
+
+
+@dataclass(slots=True)
+class WorkflowContext:
+    context_id: str
+    execution_mode: WorkflowExecutionMode
+    workflow_run_id: uuid.UUID | None = None
+    ordered_nodes: list[WorkflowNode] = field(default_factory=list)
+
+
+@dataclass(slots=True)
+class WorkflowDispatchResult:
+    context: WorkflowContext
+    task_ids: list[str]
+    node_task_ids: dict[str, str]
+    skipped_node_ids: list[str]
+
+
 # ---------------------------------------------------------------------------
 # Step → Celery task name mapping
 #
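A quick sketch of how the two new runtime types compose (imports match the module paths used elsewhere in this commit; the empty node list is just for brevity):

```python
import uuid

from app.domains.rendering.workflow_executor import (
    WorkflowContext,
    WorkflowDispatchResult,
)

# workflow_run_id starts unset; create_workflow_run() back-fills it later.
context = WorkflowContext(context_id=str(uuid.uuid4()), execution_mode="shadow")
result = WorkflowDispatchResult(
    context=context,
    task_ids=["task-1"],
    node_task_ids={"render": "task-1"},
    skipped_node_ids=["setup"],
)
assert result.context.workflow_run_id is None
```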
@@ -59,7 +83,76 @@ STEP_TASK_MAP: dict[StepName, str] = {
 }


-def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]:
+def prepare_workflow_context(
+    workflow_config: dict,
+    *,
+    context_id: str,
+    execution_mode: WorkflowExecutionMode = "graph",
+    workflow_run_id: uuid.UUID | None = None,
+) -> WorkflowContext:
+    """Validate workflow config and build the runtime context."""
+    config = WorkflowConfig.model_validate(workflow_config)
+    ordered_nodes = _topological_sort(config)
+    return WorkflowContext(
+        context_id=context_id,
+        execution_mode=execution_mode,
+        workflow_run_id=workflow_run_id,
+        ordered_nodes=ordered_nodes,
+    )
+
+
+def dispatch_prepared_workflow(context: WorkflowContext) -> WorkflowDispatchResult:
+    """Execute prepared workflow nodes in topological order as individual Celery tasks."""
+    from app.tasks.celery_app import celery_app
+
+    task_ids: list[str] = []
+    node_task_ids: dict[str, str] = {}
+    skipped_node_ids: list[str] = []
+
+    for node in context.ordered_nodes:
+        task_name = STEP_TASK_MAP.get(node.step)
+        if task_name is None:
+            logger.warning(
+                "[WORKFLOW] No Celery task mapping for step %r in mode %s - skipping node %r",
+                node.step,
+                context.execution_mode,
+                node.id,
+            )
+            skipped_node_ids.append(node.id)
+            continue
+
+        result = celery_app.send_task(task_name, args=[context.context_id], kwargs=node.params)
+        task_ids.append(result.id)
+        node_task_ids[node.id] = result.id
+        logger.info(
+            "[WORKFLOW] Dispatched node %r (step=%s, mode=%s, run=%s) -> Celery task %s",
+            node.id,
+            node.step,
+            context.execution_mode,
+            context.workflow_run_id,
+            result.id,
+        )
+
+    logger.info(
+        "[WORKFLOW] dispatch_prepared_workflow complete: %d task(s) dispatched for context %s",
+        len(task_ids),
+        context.context_id,
+    )
+
+    return WorkflowDispatchResult(
+        context=context,
+        task_ids=task_ids,
+        node_task_ids=node_task_ids,
+        skipped_node_ids=skipped_node_ids,
+    )
+
+
+def dispatch_workflow_with_context(
+    workflow_config: dict,
+    *,
+    context_id: str,
+    execution_mode: WorkflowExecutionMode = "graph",
+    workflow_run_id: uuid.UUID | None = None,
+) -> WorkflowDispatchResult:
     """Execute workflow nodes in topological order as individual Celery tasks.

     Args:
@@ -70,47 +163,25 @@ def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]:
             tasks).

     Returns:
-        List of dispatched Celery task ID strings (one per mapped node).
+        Dispatch result including runtime context, dispatched task IDs and skipped nodes.

     Raises:
         pydantic.ValidationError: if *workflow_config* does not conform to
             WorkflowConfig schema.
         ValueError: if the node graph contains a cycle.
     """
-    from app.tasks.celery_app import celery_app
-
-    # Validate + parse config
-    config = WorkflowConfig.model_validate(workflow_config)
-
-    # Topological sort so that dependency nodes dispatch first
-    ordered_nodes = _topological_sort(config)
-
-    task_ids: list[str] = []
-    for node in ordered_nodes:
-        task_name = STEP_TASK_MAP.get(node.step)
-        if task_name is None:
-            logger.warning(
-                "[WORKFLOW] No Celery task mapping for step %r — skipping node %r",
-                node.step,
-                node.id,
-            )
-            continue
-
-        result = celery_app.send_task(task_name, args=[context_id], kwargs=node.params)
-        task_ids.append(result.id)
-        logger.info(
-            "[WORKFLOW] Dispatched node %r (step=%s) → Celery task %s",
-            node.id,
-            node.step,
-            result.id,
-        )
-
-    logger.info(
-        "[WORKFLOW] dispatch_workflow complete: %d task(s) dispatched for context %s",
-        len(task_ids),
-        context_id,
-    )
-
-    return task_ids
+    context = prepare_workflow_context(
+        workflow_config,
+        context_id=context_id,
+        execution_mode=execution_mode,
+        workflow_run_id=workflow_run_id,
+    )
+    return dispatch_prepared_workflow(context)
+
+
+def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]:
+    """Backward-compatible wrapper returning only dispatched Celery task IDs."""
+    return dispatch_workflow_with_context(workflow_config, context_id=context_id).task_ids


 # ---------------------------------------------------------------------------
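Taken together, the split lets callers validate and topologically sort the graph before any Celery traffic, persist a run against the prepared context, and only then send tasks. A minimal sketch (config values and the context id are illustrative):

```python
from app.domains.rendering.workflow_executor import (
    dispatch_prepared_workflow,
    prepare_workflow_context,
)

config = {
    "version": 1,
    "nodes": [{"id": "render", "step": "blender_still", "params": {"width": 640, "height": 640}}],
    "edges": [],
}

# Raises pydantic.ValidationError / ValueError here, before any task is queued.
context = prepare_workflow_context(config, context_id="ctx-123", execution_mode="graph")

# A WorkflowRun can be created against `context` at this point (see the new
# workflow_run_service below) so that context.workflow_run_id is set.
result = dispatch_prepared_workflow(context)
print(result.task_ids, result.node_task_ids, result.skipped_node_ids)
```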
@@ -200,6 +200,9 @@ async def list_workflow_runs(


 class WorkflowDispatchResponse(BaseModel):
+    workflow_run: WorkflowRunOut
+    context_id: str
+    execution_mode: str
     dispatched: int
     task_ids: list[str]
@@ -225,7 +228,15 @@ async def dispatch_workflow_endpoint(
     the caller can track progress.
     """
     from pydantic import ValidationError as _ValidationError
-    from app.domains.rendering.workflow_executor import dispatch_workflow
+    from app.domains.rendering.workflow_executor import (
+        dispatch_prepared_workflow,
+        prepare_workflow_context,
+    )
+    from app.domains.rendering.workflow_run_service import (
+        apply_graph_dispatch_result,
+        create_workflow_run,
+        mark_workflow_run_failed,
+    )

     result = await db.execute(
         select(WorkflowDefinition).where(WorkflowDefinition.id == workflow_id)
@@ -237,10 +248,59 @@ async def dispatch_workflow_endpoint(
         raise HTTPException(status_code=400, detail="Workflow has no config")

     try:
-        task_ids = dispatch_workflow(wf.config, context_id)
+        workflow_context = prepare_workflow_context(
+            wf.config,
+            context_id=context_id,
+            execution_mode="graph",
+        )
     except _ValidationError as exc:
         raise HTTPException(status_code=422, detail=f"Invalid workflow config: {exc.errors()}")
     except ValueError as exc:
         raise HTTPException(status_code=422, detail=str(exc))

-    return WorkflowDispatchResponse(dispatched=len(task_ids), task_ids=task_ids)
+    run_id = await db.run_sync(
+        lambda sync_session: create_workflow_run(
+            sync_session,
+            workflow_def_id=wf.id,
+            order_line_id=None,
+            workflow_context=workflow_context,
+        ).id
+    )
+    await db.commit()
+
+    try:
+        dispatch_result = dispatch_prepared_workflow(workflow_context)
+    except Exception as exc:
+        failed_result = await db.execute(
+            select(WorkflowRun)
+            .where(WorkflowRun.id == run_id)
+            .options(selectinload(WorkflowRun.node_results))
+        )
+        failed_run = failed_result.scalar_one()
+        mark_workflow_run_failed(failed_run, str(exc))
+        await db.commit()
+        raise
+
+    run_result = await db.execute(
+        select(WorkflowRun)
+        .where(WorkflowRun.id == run_id)
+        .options(selectinload(WorkflowRun.node_results))
+    )
+    run = run_result.scalar_one()
+    apply_graph_dispatch_result(run, workflow_context, dispatch_result)
+    await db.commit()
+
+    refreshed_result = await db.execute(
+        select(WorkflowRun)
+        .where(WorkflowRun.id == run_id)
+        .options(selectinload(WorkflowRun.node_results))
+    )
+    refreshed_run = refreshed_result.scalar_one()
+
+    return WorkflowDispatchResponse(
+        workflow_run=refreshed_run,
+        context_id=context_id,
+        execution_mode=workflow_context.execution_mode,
+        dispatched=len(dispatch_result.task_ids),
+        task_ids=dispatch_result.task_ids,
+    )
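End to end, the extended endpoint now returns the persisted run with per-node results. A hypothetical client call (host, workflow id, and token are placeholders; the query parameter and response keys follow the endpoint test below):

```python
import httpx

resp = httpx.post(
    "http://localhost:8000/api/workflows/<workflow-id>/dispatch",  # placeholder host/id
    params={"context_id": "ctx-123"},
    headers={"Authorization": "Bearer <token>"},  # placeholder credentials
)
body = resp.json()
print(body["execution_mode"], body["dispatched"], body["task_ids"])
for node in body["workflow_run"]["node_results"]:
    # Dispatched nodes report "queued" with a task_id; unmapped ones "skipped".
    print(node["node_name"], node["status"], node["output"].get("task_id"))
```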
@@ -0,0 +1,81 @@
from __future__ import annotations
from datetime import datetime
from sqlalchemy.orm import Session
from app.domains.rendering.models import WorkflowNodeResult, WorkflowRun
from app.domains.rendering.workflow_executor import WorkflowContext, WorkflowDispatchResult
def create_workflow_run(
session: Session,
*,
workflow_def_id,
order_line_id,
workflow_context: WorkflowContext,
) -> WorkflowRun:
run = WorkflowRun(
workflow_def_id=workflow_def_id,
order_line_id=order_line_id,
status="pending",
started_at=datetime.utcnow(),
)
session.add(run)
session.flush()
workflow_context.workflow_run_id = run.id
for node in workflow_context.ordered_nodes:
metadata = {"step": node.step.value}
if node.ui and node.ui.label:
metadata["label"] = node.ui.label
session.add(
WorkflowNodeResult(
run_id=run.id,
node_name=node.id,
status="pending",
output=metadata,
)
)
session.flush()
return run
def apply_graph_dispatch_result(
run: WorkflowRun,
workflow_context: WorkflowContext,
dispatch_result: WorkflowDispatchResult,
) -> None:
node_map = {node.id: node for node in workflow_context.ordered_nodes}
results_by_name = {node_result.node_name: node_result for node_result in run.node_results}
run.celery_task_id = dispatch_result.task_ids[0] if dispatch_result.task_ids else None
for node_id, node_result in results_by_name.items():
node = node_map.get(node_id)
if node is None:
continue
metadata = dict(node_result.output or {})
metadata.setdefault("step", node.step.value)
if node.ui and node.ui.label:
metadata.setdefault("label", node.ui.label)
task_id = dispatch_result.node_task_ids.get(node_id)
if task_id is not None:
node_result.status = "queued"
metadata["task_id"] = task_id
node_result.output = metadata
continue
if node_id in dispatch_result.skipped_node_ids:
node_result.status = "skipped"
node_result.output = metadata
node_result.log = f"No Celery task mapping for step '{node.step.value}'"
def mark_workflow_run_failed(run: WorkflowRun, error_message: str) -> None:
run.status = "failed"
run.completed_at = datetime.utcnow()
run.error_message = error_message[:2000]
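A minimal sketch of the intended lifecycle across the two modules, matching how the dispatch service and endpoint use them (a synchronous SQLAlchemy `Session` is assumed; `engine` and `wf_def` are placeholders):

```python
from sqlalchemy.orm import Session

from app.domains.rendering.workflow_executor import (
    dispatch_prepared_workflow,
    prepare_workflow_context,
)
from app.domains.rendering.workflow_run_service import (
    apply_graph_dispatch_result,
    create_workflow_run,
    mark_workflow_run_failed,
)

with Session(engine) as session:  # `engine` is a placeholder
    context = prepare_workflow_context(wf_def.config, context_id="ctx-1")

    # Flushes the run, seeds one pending WorkflowNodeResult per node, and
    # back-fills context.workflow_run_id as a side effect.
    run = create_workflow_run(
        session,
        workflow_def_id=wf_def.id,
        order_line_id=None,
        workflow_context=context,
    )
    session.commit()

    try:
        result = dispatch_prepared_workflow(context)
    except Exception as exc:
        mark_workflow_run_failed(run, str(exc))
    else:
        # Marks dispatched nodes "queued", skipped ones "skipped", and records
        # per-node Celery task ids in each node result's output metadata.
        apply_graph_dispatch_result(run, context, result)
    session.commit()
```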
@@ -0,0 +1,240 @@
from __future__ import annotations
import uuid
import pytest
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from app.config import settings
from app.domains.orders.models import Order, OrderLine
from app.domains.products.models import Product
from app.domains.rendering.dispatch_service import dispatch_render_with_workflow
from app.domains.rendering.models import OutputType, WorkflowDefinition, WorkflowRun
from app.domains.rendering.workflow_config_utils import build_preset_workflow_config
def _use_test_database(monkeypatch) -> None:
monkeypatch.setattr(settings, "postgres_host", "postgres")
monkeypatch.setattr(settings, "postgres_port", 5432)
monkeypatch.setattr(settings, "postgres_user", "hartomat")
monkeypatch.setattr(settings, "postgres_password", "hartomat")
monkeypatch.setattr(settings, "postgres_db", "hartomat_test")
async def _seed_order_line(
db,
admin_user,
*,
workflow_config: dict | None = None,
) -> dict[str, object]:
product = Product(
pim_id=f"PIM-{uuid.uuid4().hex[:8]}",
name="Workflow Test Product",
)
output_type = OutputType(
name=f"Workflow Output {uuid.uuid4().hex[:8]}",
render_backend="auto",
)
order = Order(
order_number=f"WF-{uuid.uuid4().hex[:10]}",
created_by=admin_user.id,
)
db.add_all([product, output_type, order])
await db.flush()
workflow_definition = None
if workflow_config is not None:
workflow_definition = WorkflowDefinition(
name=f"Workflow {uuid.uuid4().hex[:8]}",
output_type_id=output_type.id,
config=workflow_config,
is_active=True,
)
db.add(workflow_definition)
await db.flush()
output_type.workflow_definition_id = workflow_definition.id
order_line = OrderLine(
order_id=order.id,
product_id=product.id,
output_type_id=output_type.id,
)
db.add(order_line)
await db.commit()
return {
"order_line": order_line,
"workflow_definition": workflow_definition,
"output_type": output_type,
}
@pytest.mark.asyncio
async def test_dispatch_render_with_workflow_falls_back_to_legacy_without_workflow_definition(
db,
admin_user,
monkeypatch,
):
_use_test_database(monkeypatch)
seeded = await _seed_order_line(db, admin_user)
monkeypatch.setattr(
"app.domains.rendering.dispatch_service._legacy_dispatch",
lambda order_line_id: {"backend": "legacy", "order_line_id": order_line_id},
)
result = dispatch_render_with_workflow(str(seeded["order_line"].id))
await db.rollback()
assert result == {
"backend": "legacy",
"order_line_id": str(seeded["order_line"].id),
}
runs = (await db.execute(select(WorkflowRun))).scalars().all()
assert runs == []
@pytest.mark.asyncio
async def test_dispatch_render_with_workflow_creates_run_and_node_results_for_preset_dispatch(
db,
admin_user,
monkeypatch,
):
_use_test_database(monkeypatch)
seeded = await _seed_order_line(
db,
admin_user,
workflow_config=build_preset_workflow_config("still", {"width": 1024, "height": 1024}),
)
monkeypatch.setattr(
"app.domains.rendering.workflow_builder.dispatch_workflow",
lambda workflow_type, order_line_id, params=None: "canvas-123",
)
result = dispatch_render_with_workflow(str(seeded["order_line"].id))
await db.rollback()
run_result = await db.execute(
select(WorkflowRun)
.where(WorkflowRun.id == uuid.UUID(result["workflow_run_id"]))
.options(selectinload(WorkflowRun.node_results))
)
run = run_result.scalar_one()
assert result["backend"] == "workflow"
assert result["workflow_type"] == "still"
assert result["celery_task_id"] == "canvas-123"
assert run.workflow_def_id == seeded["workflow_definition"].id
assert run.order_line_id == seeded["order_line"].id
assert run.celery_task_id == "canvas-123"
assert {node_result.node_name for node_result in run.node_results} == {
"setup",
"template",
"render",
"output",
}
assert all(node_result.status == "pending" for node_result in run.node_results)
@pytest.mark.asyncio
async def test_dispatch_render_with_workflow_falls_back_when_workflow_runtime_preparation_is_invalid(
db,
admin_user,
monkeypatch,
):
_use_test_database(monkeypatch)
seeded = await _seed_order_line(
db,
admin_user,
workflow_config={
"version": 1,
"nodes": [
{"id": "render", "step": "blender_still", "params": {}},
],
"edges": [
{"from": "missing", "to": "render"},
],
},
)
monkeypatch.setattr(
"app.domains.rendering.dispatch_service._legacy_dispatch",
lambda order_line_id: {"backend": "legacy", "order_line_id": order_line_id},
)
result = dispatch_render_with_workflow(str(seeded["order_line"].id))
await db.rollback()
assert result == {
"backend": "legacy",
"order_line_id": str(seeded["order_line"].id),
}
runs = (await db.execute(select(WorkflowRun))).scalars().all()
assert runs == []
@pytest.mark.asyncio
async def test_workflow_dispatch_endpoint_returns_workflow_run_with_node_results(
client,
db,
auth_headers,
monkeypatch,
):
workflow_definition = WorkflowDefinition(
name=f"Dispatch Workflow {uuid.uuid4().hex[:8]}",
config=build_preset_workflow_config("still_with_exports", {"width": 640, "height": 640}),
is_active=True,
)
db.add(workflow_definition)
await db.commit()
await db.refresh(workflow_definition)
calls: list[tuple[str, list[str], dict]] = []
def _fake_send_task(task_name: str, args: list[str], kwargs: dict):
calls.append((task_name, args, kwargs))
return type("Result", (), {"id": f"task-{len(calls)}"})()
context_id = str(uuid.uuid4())
monkeypatch.setattr("app.tasks.celery_app.celery_app.send_task", _fake_send_task)
response = await client.post(
f"/api/workflows/{workflow_definition.id}/dispatch",
params={"context_id": context_id},
headers=auth_headers,
)
assert response.status_code == 200
body = response.json()
assert body["context_id"] == context_id
assert body["execution_mode"] == "graph"
assert body["dispatched"] == 2
assert body["task_ids"] == ["task-1", "task-2"]
assert calls == [
(
"app.domains.rendering.tasks.render_order_line_still_task",
[context_id],
{"width": 640, "height": 640},
),
(
"app.domains.rendering.tasks.export_blend_for_order_line_task",
[context_id],
{},
),
]
node_results = {node["node_name"]: node for node in body["workflow_run"]["node_results"]}
assert body["workflow_run"]["status"] == "pending"
assert body["workflow_run"]["celery_task_id"] == "task-1"
assert node_results["render"]["status"] == "queued"
assert node_results["render"]["output"]["task_id"] == "task-1"
assert node_results["blend"]["status"] == "queued"
assert node_results["blend"]["output"]["task_id"] == "task-2"
assert node_results["setup"]["status"] == "skipped"
assert node_results["template"]["status"] == "skipped"
assert node_results["output"]["status"] == "skipped"
@@ -26,10 +26,11 @@
 ### Phase 4

-- [ ] Workflow context introduced
+- [x] Workflow context introduced
 - [ ] Node outputs are persisted and reusable
 - [ ] Graph runtime supports legacy fallback
 - [ ] `legacy`, `graph`, and `shadow` modes exist
+- Progress: Phase 4 foundation now persists `WorkflowRun` and initial `WorkflowNodeResult` records for both linked workflow dispatch and `/api/workflows/{id}/dispatch`, while keeping the legacy preset dispatcher as the safe default fallback.

 ### Phase 5
@@ -70,7 +70,7 @@
 ### Tickets

-- `E4-T1` Introduce `WorkflowContext`.
+- `E4-T1` Introduce `WorkflowContext`. `completed`
 - `E4-T2` Refactor executor to process nodes against context and node outputs.
 - `E4-T3` Persist node-level run records, logs, timings, and outputs.
 - `E4-T4` Support retry and failure policies.