feat: execute workflow bridge nodes in graph runtime
This commit is contained in:
@@ -0,0 +1,396 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.core.process_steps import StepName
|
||||
from app.domains.rendering.models import WorkflowNodeResult, WorkflowRun
|
||||
from app.domains.rendering.workflow_executor import STEP_TASK_MAP, WorkflowContext, WorkflowDispatchResult
|
||||
from app.domains.rendering.workflow_node_registry import get_node_definition
|
||||
from app.domains.rendering.workflow_runtime_services import (
|
||||
AutoPopulateMaterialsResult,
|
||||
BBoxResolutionResult,
|
||||
MaterialResolutionResult,
|
||||
OrderLineRenderSetupResult,
|
||||
TemplateResolutionResult,
|
||||
auto_populate_materials_for_cad,
|
||||
prepare_order_line_render_context,
|
||||
resolve_cad_bbox,
|
||||
resolve_order_line_material_map,
|
||||
resolve_order_line_template_context,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowGraphRuntimeError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class WorkflowGraphState:
|
||||
setup: OrderLineRenderSetupResult | None = None
|
||||
template: TemplateResolutionResult | None = None
|
||||
materials: MaterialResolutionResult | None = None
|
||||
auto_populate: AutoPopulateMaterialsResult | None = None
|
||||
bbox: BBoxResolutionResult | None = None
|
||||
node_outputs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
|
||||
|
||||
_ORDER_LINE_RENDER_STEPS = {
|
||||
StepName.BLENDER_STILL,
|
||||
StepName.BLENDER_TURNTABLE,
|
||||
StepName.EXPORT_BLEND,
|
||||
StepName.OUTPUT_SAVE,
|
||||
StepName.NOTIFY,
|
||||
}
|
||||
|
||||
|
||||
def execute_graph_workflow(
|
||||
session: Session,
|
||||
workflow_context: WorkflowContext,
|
||||
) -> WorkflowDispatchResult:
|
||||
if workflow_context.workflow_run_id is None:
|
||||
raise ValueError("workflow_context.workflow_run_id is required for graph execution")
|
||||
|
||||
run = session.execute(
|
||||
select(WorkflowRun)
|
||||
.where(WorkflowRun.id == workflow_context.workflow_run_id)
|
||||
.options(selectinload(WorkflowRun.node_results))
|
||||
).scalar_one()
|
||||
|
||||
node_results = {node_result.node_name: node_result for node_result in run.node_results}
|
||||
state = WorkflowGraphState()
|
||||
task_ids: list[str] = []
|
||||
node_task_ids: dict[str, str] = {}
|
||||
skipped_node_ids: list[str] = []
|
||||
|
||||
for node in workflow_context.ordered_nodes:
|
||||
node_result = node_results.get(node.id)
|
||||
if node_result is None:
|
||||
logger.warning(
|
||||
"[WORKFLOW] Missing WorkflowNodeResult row for node %s on run %s",
|
||||
node.id,
|
||||
run.id,
|
||||
)
|
||||
continue
|
||||
|
||||
metadata = _base_output(node_result.output, node)
|
||||
definition = get_node_definition(node.step)
|
||||
bridge_executor = _BRIDGE_EXECUTORS.get(node.step)
|
||||
|
||||
if bridge_executor is not None:
|
||||
started = time.perf_counter()
|
||||
node_result.status = "running"
|
||||
node_result.output = dict(metadata)
|
||||
session.flush()
|
||||
try:
|
||||
payload, status, log_message = bridge_executor(
|
||||
session=session,
|
||||
workflow_context=workflow_context,
|
||||
state=state,
|
||||
node_params=node.params,
|
||||
)
|
||||
except Exception as exc:
|
||||
node_result.status = "failed"
|
||||
node_result.log = str(exc)[:2000]
|
||||
node_result.duration_s = round(time.perf_counter() - started, 4)
|
||||
node_result.output = dict(metadata)
|
||||
session.flush()
|
||||
raise WorkflowGraphRuntimeError(
|
||||
f"Node '{node.id}' ({node.step.value}) failed: {exc}"
|
||||
) from exc
|
||||
|
||||
if payload:
|
||||
metadata.update(payload)
|
||||
state.node_outputs[node.id] = payload
|
||||
|
||||
node_result.status = status
|
||||
node_result.log = log_message
|
||||
node_result.output = dict(metadata)
|
||||
node_result.duration_s = round(time.perf_counter() - started, 4)
|
||||
session.flush()
|
||||
|
||||
if status == "failed":
|
||||
raise WorkflowGraphRuntimeError(
|
||||
f"Node '{node.id}' ({node.step.value}) failed: {log_message or 'unknown error'}"
|
||||
)
|
||||
if status == "skipped":
|
||||
skipped_node_ids.append(node.id)
|
||||
continue
|
||||
|
||||
task_name = STEP_TASK_MAP.get(node.step)
|
||||
if task_name is not None:
|
||||
if node.step in _ORDER_LINE_RENDER_STEPS and state.setup is not None and not state.setup.is_ready:
|
||||
metadata["blocked_by"] = "order_line_setup"
|
||||
node_result.status = "skipped"
|
||||
node_result.output = metadata
|
||||
node_result.log = (
|
||||
f"Skipped because order_line_setup did not complete successfully "
|
||||
f"({state.setup.status})"
|
||||
)
|
||||
node_result.duration_s = None
|
||||
session.flush()
|
||||
skipped_node_ids.append(node.id)
|
||||
continue
|
||||
|
||||
from app.tasks.celery_app import celery_app
|
||||
|
||||
result = celery_app.send_task(task_name, args=[workflow_context.context_id], kwargs=node.params)
|
||||
metadata["task_id"] = result.id
|
||||
if definition is not None:
|
||||
metadata["execution_kind"] = definition.execution_kind
|
||||
node_result.status = "queued"
|
||||
node_result.output = metadata
|
||||
node_result.log = None
|
||||
node_result.duration_s = None
|
||||
session.flush()
|
||||
task_ids.append(result.id)
|
||||
node_task_ids[node.id] = result.id
|
||||
logger.info(
|
||||
"[WORKFLOW] Dispatched node %r (step=%s, mode=%s, run=%s) -> Celery task %s",
|
||||
node.id,
|
||||
node.step,
|
||||
workflow_context.execution_mode,
|
||||
workflow_context.workflow_run_id,
|
||||
result.id,
|
||||
)
|
||||
continue
|
||||
|
||||
metadata["execution_kind"] = definition.execution_kind if definition is not None else "bridge"
|
||||
node_result.status = "skipped"
|
||||
node_result.output = metadata
|
||||
node_result.log = f"Graph runtime not implemented for step '{node.step.value}'"
|
||||
node_result.duration_s = None
|
||||
session.flush()
|
||||
skipped_node_ids.append(node.id)
|
||||
|
||||
run.celery_task_id = task_ids[0] if task_ids else None
|
||||
if any(node_result.status == "failed" for node_result in run.node_results):
|
||||
run.status = "failed"
|
||||
run.completed_at = datetime.utcnow()
|
||||
elif task_ids:
|
||||
run.status = "pending"
|
||||
run.completed_at = None
|
||||
else:
|
||||
run.status = "completed"
|
||||
run.completed_at = datetime.utcnow()
|
||||
session.flush()
|
||||
|
||||
return WorkflowDispatchResult(
|
||||
context=workflow_context,
|
||||
task_ids=task_ids,
|
||||
node_task_ids=node_task_ids,
|
||||
skipped_node_ids=skipped_node_ids,
|
||||
)
|
||||
|
||||
|
||||
def _base_output(existing: dict[str, Any] | None, node) -> dict[str, Any]:
|
||||
metadata = dict(existing or {})
|
||||
metadata.setdefault("step", node.step.value)
|
||||
if node.ui and node.ui.label:
|
||||
metadata.setdefault("label", node.ui.label)
|
||||
definition = get_node_definition(node.step)
|
||||
if definition is not None:
|
||||
metadata.setdefault("execution_kind", definition.execution_kind)
|
||||
return metadata
|
||||
|
||||
|
||||
def _serialize_setup_result(result: OrderLineRenderSetupResult) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {
|
||||
"setup_status": result.status,
|
||||
"reason": result.reason,
|
||||
"materials_source_count": len(result.materials_source or []),
|
||||
"part_colors_count": len(result.part_colors or {}),
|
||||
"usd_render_path": str(result.usd_render_path) if result.usd_render_path else None,
|
||||
"glb_reuse_path": str(result.glb_reuse_path) if result.glb_reuse_path else None,
|
||||
}
|
||||
if result.order_line is not None:
|
||||
payload["order_line_id"] = str(result.order_line.id)
|
||||
payload["product_id"] = str(result.order_line.product_id) if result.order_line.product_id else None
|
||||
payload["output_type_id"] = str(result.order_line.output_type_id) if result.order_line.output_type_id else None
|
||||
if result.order is not None:
|
||||
payload["order_id"] = str(result.order.id)
|
||||
payload["order_status"] = result.order.status.value if getattr(result.order, "status", None) else None
|
||||
if result.cad_file is not None:
|
||||
payload["cad_file_id"] = str(result.cad_file.id)
|
||||
payload["step_path"] = result.cad_file.stored_path
|
||||
return payload
|
||||
|
||||
|
||||
def _serialize_template_result(result: TemplateResolutionResult) -> dict[str, Any]:
|
||||
return {
|
||||
"template_id": str(result.template.id) if result.template is not None else None,
|
||||
"template_name": result.template.name if result.template is not None else None,
|
||||
"template_path": result.template.blend_file_path if result.template is not None else None,
|
||||
"material_library": result.material_library,
|
||||
"material_map": result.material_map,
|
||||
"material_map_count": len(result.material_map or {}),
|
||||
"use_materials": result.use_materials,
|
||||
"override_material": result.override_material,
|
||||
"category_key": result.category_key,
|
||||
"output_type_id": result.output_type_id,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_material_result(result: MaterialResolutionResult) -> dict[str, Any]:
|
||||
return {
|
||||
"material_map": result.material_map,
|
||||
"material_map_count": len(result.material_map or {}),
|
||||
"use_materials": result.use_materials,
|
||||
"override_material": result.override_material,
|
||||
"source_material_count": result.source_material_count,
|
||||
"resolved_material_count": result.resolved_material_count,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_auto_populate_result(result: AutoPopulateMaterialsResult) -> dict[str, Any]:
|
||||
return {
|
||||
"cad_file_id": result.cad_file_id,
|
||||
"updated_product_ids": result.updated_product_ids,
|
||||
"updated_product_count": len(result.updated_product_ids),
|
||||
"queued_thumbnail_regeneration": result.queued_thumbnail_regeneration,
|
||||
"part_colors": result.part_colors,
|
||||
"part_colors_count": len(result.part_colors or {}),
|
||||
"cad_parts": result.cad_parts,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_bbox_result(result: BBoxResolutionResult) -> dict[str, Any]:
|
||||
return {
|
||||
"bbox_data": result.bbox_data,
|
||||
"has_bbox": result.has_bbox,
|
||||
"source_kind": result.source_kind,
|
||||
"step_path": result.step_path,
|
||||
"glb_path": result.glb_path,
|
||||
}
|
||||
|
||||
|
||||
def _execute_order_line_setup(
|
||||
*,
|
||||
session: Session,
|
||||
workflow_context: WorkflowContext,
|
||||
state: WorkflowGraphState,
|
||||
node_params: dict[str, Any],
|
||||
) -> tuple[dict[str, Any], str, str | None]:
|
||||
del node_params
|
||||
setup = prepare_order_line_render_context(session, workflow_context.context_id)
|
||||
state.setup = setup
|
||||
payload = _serialize_setup_result(setup)
|
||||
if setup.status == "ready":
|
||||
return payload, "completed", None
|
||||
if setup.status == "skip":
|
||||
return payload, "skipped", setup.reason
|
||||
return payload, "failed", setup.reason or "order_line_setup_failed"
|
||||
|
||||
|
||||
def _execute_resolve_template(
|
||||
*,
|
||||
session: Session,
|
||||
workflow_context: WorkflowContext,
|
||||
state: WorkflowGraphState,
|
||||
node_params: dict[str, Any],
|
||||
) -> tuple[dict[str, Any], str, str | None]:
|
||||
del workflow_context, node_params
|
||||
if state.setup is None or not state.setup.is_ready:
|
||||
if state.setup is not None and state.setup.status == "skip":
|
||||
return _serialize_setup_result(state.setup), "skipped", state.setup.reason
|
||||
raise WorkflowGraphRuntimeError("resolve_template requires a ready order_line_setup result")
|
||||
result = resolve_order_line_template_context(session, state.setup)
|
||||
state.template = result
|
||||
return _serialize_template_result(result), "completed", None
|
||||
|
||||
|
||||
def _execute_material_map_resolve(
|
||||
*,
|
||||
session: Session,
|
||||
workflow_context: WorkflowContext,
|
||||
state: WorkflowGraphState,
|
||||
node_params: dict[str, Any],
|
||||
) -> tuple[dict[str, Any], str, str | None]:
|
||||
del session, workflow_context, node_params
|
||||
if state.setup is None or not state.setup.is_ready:
|
||||
if state.setup is not None and state.setup.status == "skip":
|
||||
return _serialize_setup_result(state.setup), "skipped", state.setup.reason
|
||||
raise WorkflowGraphRuntimeError("material_map_resolve requires a ready order_line_setup result")
|
||||
|
||||
line = state.setup.order_line
|
||||
cad_file = state.setup.cad_file
|
||||
if line is None:
|
||||
raise WorkflowGraphRuntimeError("material_map_resolve requires an order line")
|
||||
|
||||
material_library = state.template.material_library if state.template is not None else None
|
||||
template = state.template.template if state.template is not None else None
|
||||
result = resolve_order_line_material_map(
|
||||
line,
|
||||
cad_file,
|
||||
state.setup.materials_source,
|
||||
material_library=material_library,
|
||||
template=template,
|
||||
)
|
||||
state.materials = result
|
||||
return _serialize_material_result(result), "completed", None
|
||||
|
||||
|
||||
def _execute_auto_populate_materials(
|
||||
*,
|
||||
session: Session,
|
||||
workflow_context: WorkflowContext,
|
||||
state: WorkflowGraphState,
|
||||
node_params: dict[str, Any],
|
||||
) -> tuple[dict[str, Any], str, str | None]:
|
||||
del workflow_context, node_params
|
||||
if state.setup is None or state.setup.cad_file is None:
|
||||
if state.setup is not None and state.setup.status == "skip":
|
||||
return _serialize_setup_result(state.setup), "skipped", state.setup.reason
|
||||
raise WorkflowGraphRuntimeError("auto_populate_materials requires a resolved cad_file")
|
||||
result = auto_populate_materials_for_cad(session, str(state.setup.cad_file.id))
|
||||
state.auto_populate = result
|
||||
if state.setup.order_line is not None and state.setup.order_line.product is not None:
|
||||
session.refresh(state.setup.order_line.product)
|
||||
state.setup.materials_source = state.setup.order_line.product.cad_part_materials or []
|
||||
return _serialize_auto_populate_result(result), "completed", None
|
||||
|
||||
|
||||
def _execute_glb_bbox(
|
||||
*,
|
||||
session: Session,
|
||||
workflow_context: WorkflowContext,
|
||||
state: WorkflowGraphState,
|
||||
node_params: dict[str, Any],
|
||||
) -> tuple[dict[str, Any], str, str | None]:
|
||||
del session, workflow_context
|
||||
if state.setup is None or state.setup.cad_file is None:
|
||||
if state.setup is not None and state.setup.status == "skip":
|
||||
return _serialize_setup_result(state.setup), "skipped", state.setup.reason
|
||||
raise WorkflowGraphRuntimeError("glb_bbox requires a resolved cad_file")
|
||||
|
||||
step_path = state.setup.cad_file.stored_path
|
||||
glb_path = node_params.get("glb_path")
|
||||
if glb_path is None and state.setup.glb_reuse_path is not None:
|
||||
glb_path = str(state.setup.glb_reuse_path)
|
||||
elif glb_path is None:
|
||||
step_file = Path(step_path)
|
||||
fallback_glb = step_file.parent / f"{step_file.stem}_thumbnail.glb"
|
||||
if fallback_glb.exists():
|
||||
glb_path = str(fallback_glb)
|
||||
|
||||
result = resolve_cad_bbox(step_path, glb_path=glb_path)
|
||||
state.bbox = result
|
||||
return _serialize_bbox_result(result), "completed", None
|
||||
|
||||
|
||||
_BRIDGE_EXECUTORS = {
|
||||
StepName.ORDER_LINE_SETUP: _execute_order_line_setup,
|
||||
StepName.RESOLVE_TEMPLATE: _execute_resolve_template,
|
||||
StepName.MATERIAL_MAP_RESOLVE: _execute_material_map_resolve,
|
||||
StepName.AUTO_POPULATE_MATERIALS: _execute_auto_populate_materials,
|
||||
StepName.GLB_BBOX: _execute_glb_bbox,
|
||||
}
|
||||
Reference in New Issue
Block a user