feat: add workflow run dispatch foundation
This commit is contained in:
@@ -13,7 +13,6 @@ task-dispatch logic in app.services.render_dispatcher (legacy path).
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -32,7 +31,9 @@ def dispatch_render_with_workflow(order_line_id: str) -> dict:
|
||||
|
||||
from app.config import settings
|
||||
from app.domains.orders.models import OrderLine
|
||||
from app.domains.rendering.models import OutputType, WorkflowDefinition, WorkflowRun
|
||||
from app.domains.rendering.models import OutputType, WorkflowDefinition
|
||||
from app.domains.rendering.workflow_executor import prepare_workflow_context
|
||||
from app.domains.rendering.workflow_run_service import create_workflow_run, mark_workflow_run_failed
|
||||
|
||||
engine = create_engine(
|
||||
settings.database_url.replace("+asyncpg", ""),
|
||||
@@ -96,6 +97,22 @@ def dispatch_render_with_workflow(order_line_id: str) -> dict:
|
||||
workflow_type,
|
||||
)
|
||||
|
||||
try:
|
||||
workflow_context = prepare_workflow_context(
|
||||
wf_def.config,
|
||||
context_id=order_line_id,
|
||||
execution_mode="legacy",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"order_line %s: workflow_definition_id %s failed runtime preparation (%s), "
|
||||
"falling back to legacy dispatch",
|
||||
order_line_id,
|
||||
wf_def.id,
|
||||
exc,
|
||||
)
|
||||
return _legacy_dispatch(order_line_id)
|
||||
|
||||
# For turntable workflows: resolve step_path + output_dir from the order line at runtime
|
||||
if workflow_type == "turntable" and ("step_path" not in params or "output_dir" not in params):
|
||||
from app.domains.products.models import CadFile as _CadFile
|
||||
@@ -120,19 +137,43 @@ def dispatch_render_with_workflow(order_line_id: str) -> dict:
|
||||
str(_Path(_cfg.upload_dir) / "renders" / str(line.id)),
|
||||
)
|
||||
|
||||
from app.domains.rendering.workflow_builder import dispatch_workflow
|
||||
celery_task_id = dispatch_workflow(workflow_type, order_line_id, params)
|
||||
run = None
|
||||
try:
|
||||
run = create_workflow_run(
|
||||
session,
|
||||
workflow_def_id=wf_def.id,
|
||||
order_line_id=line.id,
|
||||
workflow_context=workflow_context,
|
||||
)
|
||||
session.commit()
|
||||
except Exception as exc:
|
||||
session.rollback()
|
||||
logger.warning(
|
||||
"order_line %s: failed to create workflow run for workflow_definition_id %s (%s), "
|
||||
"falling back to legacy dispatch",
|
||||
order_line_id,
|
||||
wf_def.id,
|
||||
exc,
|
||||
)
|
||||
return _legacy_dispatch(order_line_id)
|
||||
|
||||
# Persist a WorkflowRun record
|
||||
run = WorkflowRun(
|
||||
workflow_def_id=wf_def.id,
|
||||
order_line_id=line.id,
|
||||
celery_task_id=celery_task_id,
|
||||
status="pending",
|
||||
started_at=datetime.utcnow(),
|
||||
)
|
||||
session.add(run)
|
||||
session.commit()
|
||||
from app.domains.rendering.workflow_builder import dispatch_workflow
|
||||
|
||||
try:
|
||||
celery_task_id = dispatch_workflow(workflow_type, order_line_id, params)
|
||||
run.celery_task_id = celery_task_id
|
||||
session.commit()
|
||||
except Exception as exc:
|
||||
session.rollback()
|
||||
session.add(run)
|
||||
mark_workflow_run_failed(run, str(exc))
|
||||
session.commit()
|
||||
logger.exception(
|
||||
"order_line %s: workflow dispatch via definition %s failed, falling back to legacy dispatch",
|
||||
order_line_id,
|
||||
wf_def.id,
|
||||
)
|
||||
return _legacy_dispatch(order_line_id)
|
||||
|
||||
return {
|
||||
"backend": "workflow",
|
||||
|
||||
@@ -17,8 +17,13 @@ Nodes whose StepName has no mapping in STEP_TASK_MAP are skipped with a
|
||||
warning — this lets you add new StepName values to the enum without breaking
|
||||
existing dispatchers.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
from app.domains.rendering.workflow_schema import WorkflowConfig, WorkflowNode
|
||||
from app.core.process_steps import StepName
|
||||
@@ -26,6 +31,25 @@ from app.core.process_steps import StepName
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
WorkflowExecutionMode = Literal["legacy", "graph", "shadow"]
|
||||
|
||||
|
||||
@dataclass(slots=True)
class WorkflowContext:
    """Runtime state for one workflow execution: identity, mode and node order."""

    # Caller-supplied identifier forwarded to every dispatched task
    # (e.g. an order line id in the legacy dispatch path).
    context_id: str
    # How this workflow is being executed: "legacy", "graph" or "shadow".
    execution_mode: WorkflowExecutionMode
    # Primary key of the persisted WorkflowRun; filled in once the run row exists.
    workflow_run_id: uuid.UUID | None = None
    # Workflow nodes in topological (dependency-first) order.
    ordered_nodes: list[WorkflowNode] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(slots=True)
class WorkflowDispatchResult:
    """Outcome of dispatching a prepared workflow to Celery."""

    # The runtime context the dispatch ran with.
    context: WorkflowContext
    # Celery task ids in dispatch (topological) order, one per mapped node.
    task_ids: list[str]
    # node id -> Celery task id for every node that was actually dispatched.
    node_task_ids: dict[str, str]
    # Ids of nodes skipped because their step has no STEP_TASK_MAP entry.
    skipped_node_ids: list[str]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step → Celery task name mapping
|
||||
#
|
||||
@@ -59,7 +83,76 @@ STEP_TASK_MAP: dict[StepName, str] = {
|
||||
}
|
||||
|
||||
|
||||
def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]:
|
||||
def prepare_workflow_context(
    workflow_config: dict,
    *,
    context_id: str,
    execution_mode: WorkflowExecutionMode = "graph",
    workflow_run_id: uuid.UUID | None = None,
) -> WorkflowContext:
    """Parse and validate *workflow_config*, then assemble the runtime context.

    Raises:
        pydantic.ValidationError: if the config does not match WorkflowConfig.
        ValueError: if the node graph contains a cycle (from the topological sort).
    """
    parsed = WorkflowConfig.model_validate(workflow_config)
    return WorkflowContext(
        context_id=context_id,
        execution_mode=execution_mode,
        workflow_run_id=workflow_run_id,
        ordered_nodes=_topological_sort(parsed),
    )
|
||||
|
||||
|
||||
def dispatch_prepared_workflow(context: WorkflowContext) -> WorkflowDispatchResult:
    """Send one Celery task per mapped node, in the context's topological order.

    Nodes whose step has no entry in STEP_TASK_MAP are recorded as skipped
    rather than raising, so new StepName values do not break dispatch.
    """
    # NOTE(review): import kept function-local as in the original — presumably
    # to avoid an import cycle at module load; confirm before hoisting.
    from app.tasks.celery_app import celery_app

    dispatched: list[str] = []
    per_node: dict[str, str] = {}
    skipped: list[str] = []

    for node in context.ordered_nodes:
        mapped_task = STEP_TASK_MAP.get(node.step)
        if mapped_task is None:
            # Unmapped step: record and continue instead of failing the run.
            skipped.append(node.id)
            logger.warning(
                "[WORKFLOW] No Celery task mapping for step %r in mode %s - skipping node %r",
                node.step,
                context.execution_mode,
                node.id,
            )
            continue

        async_result = celery_app.send_task(
            mapped_task, args=[context.context_id], kwargs=node.params
        )
        dispatched.append(async_result.id)
        per_node[node.id] = async_result.id
        logger.info(
            "[WORKFLOW] Dispatched node %r (step=%s, mode=%s, run=%s) -> Celery task %s",
            node.id,
            node.step,
            context.execution_mode,
            context.workflow_run_id,
            async_result.id,
        )

    logger.info(
        "[WORKFLOW] dispatch_prepared_workflow complete: %d task(s) dispatched for context %s",
        len(dispatched),
        context.context_id,
    )
    return WorkflowDispatchResult(
        context=context,
        task_ids=dispatched,
        node_task_ids=per_node,
        skipped_node_ids=skipped,
    )
|
||||
|
||||
|
||||
def dispatch_workflow_with_context(
    workflow_config: dict,
    *,
    context_id: str,
    execution_mode: WorkflowExecutionMode = "graph",
    workflow_run_id: uuid.UUID | None = None,
) -> WorkflowDispatchResult:
    """Execute workflow nodes in topological order as individual Celery tasks.

    Args:
        workflow_config: Raw workflow definition, validated against WorkflowConfig.
        context_id: Identifier passed as the first positional argument of every task.
        execution_mode: Runtime mode recorded on the context.
        workflow_run_id: Optional WorkflowRun primary key to associate with the run.

    Returns:
        Dispatch result including runtime context, dispatched task IDs and skipped nodes.

    Raises:
        pydantic.ValidationError: if *workflow_config* does not conform to
            WorkflowConfig schema.
        ValueError: if the node graph contains a cycle.
    """
    prepared = prepare_workflow_context(
        workflow_config,
        context_id=context_id,
        execution_mode=execution_mode,
        workflow_run_id=workflow_run_id,
    )
    return dispatch_prepared_workflow(prepared)
|
||||
|
||||
|
||||
def dispatch_workflow(workflow_config: dict, context_id: str) -> list[str]:
    """Backward-compatible wrapper returning only dispatched Celery task IDs."""
    result = dispatch_workflow_with_context(workflow_config, context_id=context_id)
    return result.task_ids
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -200,6 +200,9 @@ async def list_workflow_runs(
|
||||
|
||||
|
||||
class WorkflowDispatchResponse(BaseModel):
    """API response body for the workflow dispatch endpoint."""

    # Persisted run (serialized via WorkflowRunOut) so the caller can track progress.
    workflow_run: WorkflowRunOut
    # Context id the workflow was dispatched with.
    context_id: str
    # Execution mode recorded on the runtime context (e.g. "graph").
    execution_mode: str
    # Number of Celery tasks actually dispatched.
    dispatched: int
    # Ids of the dispatched Celery tasks.
    task_ids: list[str]
|
||||
|
||||
@@ -225,7 +228,15 @@ async def dispatch_workflow_endpoint(
|
||||
the caller can track progress.
|
||||
"""
|
||||
from pydantic import ValidationError as _ValidationError
|
||||
from app.domains.rendering.workflow_executor import dispatch_workflow
|
||||
from app.domains.rendering.workflow_executor import (
|
||||
dispatch_prepared_workflow,
|
||||
prepare_workflow_context,
|
||||
)
|
||||
from app.domains.rendering.workflow_run_service import (
|
||||
apply_graph_dispatch_result,
|
||||
create_workflow_run,
|
||||
mark_workflow_run_failed,
|
||||
)
|
||||
|
||||
result = await db.execute(
|
||||
select(WorkflowDefinition).where(WorkflowDefinition.id == workflow_id)
|
||||
@@ -237,10 +248,59 @@ async def dispatch_workflow_endpoint(
|
||||
raise HTTPException(status_code=400, detail="Workflow has no config")
|
||||
|
||||
try:
|
||||
task_ids = dispatch_workflow(wf.config, context_id)
|
||||
workflow_context = prepare_workflow_context(
|
||||
wf.config,
|
||||
context_id=context_id,
|
||||
execution_mode="graph",
|
||||
)
|
||||
except _ValidationError as exc:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid workflow config: {exc.errors()}")
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc))
|
||||
|
||||
return WorkflowDispatchResponse(dispatched=len(task_ids), task_ids=task_ids)
|
||||
run_id = await db.run_sync(
|
||||
lambda sync_session: create_workflow_run(
|
||||
sync_session,
|
||||
workflow_def_id=wf.id,
|
||||
order_line_id=None,
|
||||
workflow_context=workflow_context,
|
||||
).id
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
try:
|
||||
dispatch_result = dispatch_prepared_workflow(workflow_context)
|
||||
except Exception as exc:
|
||||
failed_result = await db.execute(
|
||||
select(WorkflowRun)
|
||||
.where(WorkflowRun.id == run_id)
|
||||
.options(selectinload(WorkflowRun.node_results))
|
||||
)
|
||||
failed_run = failed_result.scalar_one()
|
||||
mark_workflow_run_failed(failed_run, str(exc))
|
||||
await db.commit()
|
||||
raise
|
||||
|
||||
run_result = await db.execute(
|
||||
select(WorkflowRun)
|
||||
.where(WorkflowRun.id == run_id)
|
||||
.options(selectinload(WorkflowRun.node_results))
|
||||
)
|
||||
run = run_result.scalar_one()
|
||||
apply_graph_dispatch_result(run, workflow_context, dispatch_result)
|
||||
await db.commit()
|
||||
|
||||
refreshed_result = await db.execute(
|
||||
select(WorkflowRun)
|
||||
.where(WorkflowRun.id == run_id)
|
||||
.options(selectinload(WorkflowRun.node_results))
|
||||
)
|
||||
refreshed_run = refreshed_result.scalar_one()
|
||||
|
||||
return WorkflowDispatchResponse(
|
||||
workflow_run=refreshed_run,
|
||||
context_id=context_id,
|
||||
execution_mode=workflow_context.execution_mode,
|
||||
dispatched=len(dispatch_result.task_ids),
|
||||
task_ids=dispatch_result.task_ids,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domains.rendering.models import WorkflowNodeResult, WorkflowRun
|
||||
from app.domains.rendering.workflow_executor import WorkflowContext, WorkflowDispatchResult
|
||||
|
||||
|
||||
def create_workflow_run(
    session: Session,
    *,
    workflow_def_id,
    order_line_id,
    workflow_context: WorkflowContext,
) -> WorkflowRun:
    """Persist a pending WorkflowRun plus one pending node result per node.

    Flushes (but does not commit) so the generated run id is available, stores
    that id back onto *workflow_context*, and seeds a WorkflowNodeResult row for
    every ordered node with step/label metadata in its output column. The caller
    owns the commit.
    """
    new_run = WorkflowRun(
        workflow_def_id=workflow_def_id,
        order_line_id=order_line_id,
        status="pending",
        started_at=datetime.utcnow(),
    )
    session.add(new_run)
    session.flush()  # assigns new_run.id without committing the transaction

    workflow_context.workflow_run_id = new_run.id

    for wf_node in workflow_context.ordered_nodes:
        node_meta = {"step": wf_node.step.value}
        label = wf_node.ui.label if wf_node.ui else None
        if label:
            node_meta["label"] = label
        session.add(
            WorkflowNodeResult(
                run_id=new_run.id,
                node_name=wf_node.id,
                status="pending",
                output=node_meta,
            )
        )

    session.flush()
    return new_run
|
||||
|
||||
|
||||
def apply_graph_dispatch_result(
    run: WorkflowRun,
    workflow_context: WorkflowContext,
    dispatch_result: WorkflowDispatchResult,
) -> None:
    """Copy dispatch outcomes onto the persisted run and its node results.

    Mutates *run* in place: the first dispatched Celery task id (if any) becomes
    the run-level task id; each node result is marked "queued" with its task id,
    or "skipped" with an explanatory log line. Node results with no matching
    workflow node are left untouched.
    """
    nodes_by_id = {wf_node.id: wf_node for wf_node in workflow_context.ordered_nodes}
    # One entry per node name (last result wins), mirroring seeded rows.
    latest_by_name = {res.node_name: res for res in run.node_results}

    run.celery_task_id = dispatch_result.task_ids[0] if dispatch_result.task_ids else None

    for name, node_result in latest_by_name.items():
        wf_node = nodes_by_id.get(name)
        if wf_node is None:
            continue

        merged = dict(node_result.output or {})
        merged.setdefault("step", wf_node.step.value)
        if wf_node.ui and wf_node.ui.label:
            merged.setdefault("label", wf_node.ui.label)

        queued_task = dispatch_result.node_task_ids.get(name)
        if queued_task is not None:
            merged["task_id"] = queued_task
            node_result.status = "queued"
            node_result.output = merged
        elif name in dispatch_result.skipped_node_ids:
            node_result.status = "skipped"
            node_result.output = merged
            node_result.log = f"No Celery task mapping for step '{wf_node.step.value}'"
|
||||
|
||||
|
||||
def mark_workflow_run_failed(run: WorkflowRun, error_message: str) -> None:
    """Transition *run* to the terminal ``failed`` state.

    Stamps the completion time and stores at most the first 2000 characters of
    *error_message*. NOTE(review): the 2000-char cap presumably matches the
    column length on the model — confirm against the schema.
    """
    run.completed_at = datetime.utcnow()
    run.status = "failed"
    run.error_message = error_message[:2000]
|
||||
Reference in New Issue
Block a user