"""Graph-based workflow runtime.

Walks the ordered nodes of a ``WorkflowRun`` and executes each one either
in-process via a "bridge" executor (for setup/resolution steps) or by
dispatching a Celery task (for render/export steps). Node progress, retries
and failures are persisted on the run's ``WorkflowNodeResult`` rows.
"""

from __future__ import annotations

import logging
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any

from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload

from app.core.process_steps import StepName
from app.domains.rendering.models import WorkflowNodeResult, WorkflowRun
from app.domains.rendering.workflow_executor import STEP_TASK_MAP, WorkflowContext, WorkflowDispatchResult
from app.domains.rendering.workflow_node_registry import get_node_definition
from app.domains.rendering.workflow_runtime_services import (
    AutoPopulateMaterialsResult,
    BBoxResolutionResult,
    MaterialResolutionResult,
    OrderLineRenderSetupResult,
    TemplateResolutionResult,
    auto_populate_materials_for_cad,
    prepare_order_line_render_context,
    resolve_cad_bbox,
    resolve_order_line_material_map,
    resolve_order_line_template_context,
)

logger = logging.getLogger(__name__)


class WorkflowGraphRuntimeError(RuntimeError):
    """Raised when a graph node fails after exhausting its retry budget."""

    pass


@dataclass(slots=True)
class WorkflowGraphState:
    """Mutable state shared by bridge executors across one graph run.

    Earlier nodes populate these fields so later nodes can read them
    (e.g. ``resolve_template`` requires ``setup`` to be ready).
    """

    # Result of the order-line setup node; gates most downstream nodes.
    setup: OrderLineRenderSetupResult | None = None
    # Resolved Blender template context (set by resolve_template).
    template: TemplateResolutionResult | None = None
    # Resolved material map (set by material_map_resolve).
    materials: MaterialResolutionResult | None = None
    # Result of auto-populating product materials from the CAD file.
    auto_populate: AutoPopulateMaterialsResult | None = None
    # Resolved bounding box for the CAD/GLB geometry.
    bbox: BBoxResolutionResult | None = None
    # Raw payload emitted by each bridge node, keyed by node id.
    node_outputs: dict[str, dict[str, Any]] = field(default_factory=dict)


# Celery-dispatched steps that must not run unless order_line_setup
# completed successfully (state.setup.is_ready).
_ORDER_LINE_RENDER_STEPS = {
    StepName.BLENDER_STILL,
    StepName.BLENDER_TURNTABLE,
    StepName.EXPORT_BLEND,
    StepName.OUTPUT_SAVE,
    StepName.NOTIFY,
}


def find_unsupported_graph_nodes(workflow_context: WorkflowContext) -> list[str]:
    """Return ids of nodes that have neither a bridge executor nor a Celery task.

    Such nodes cannot be executed by this runtime and would be marked
    "skipped" by :func:`execute_graph_workflow`.
    """
    unsupported: list[str] = []
    for node in workflow_context.ordered_nodes:
        if node.step in _BRIDGE_EXECUTORS:
            continue
        if STEP_TASK_MAP.get(node.step) is not None:
            continue
        unsupported.append(node.id)
    return unsupported


def execute_graph_workflow(
    session: Session,
    workflow_context: WorkflowContext,
) -> WorkflowDispatchResult:
    """Execute every node of the workflow graph attached to ``workflow_context``.

    Bridge nodes run synchronously in-process (with per-node retry); Celery
    nodes are dispatched asynchronously. Node status/output/log/duration are
    persisted on each ``WorkflowNodeResult`` row as execution progresses, and
    the run's overall status is updated at the end.

    Raises:
        ValueError: if ``workflow_context.workflow_run_id`` is missing.
        WorkflowGraphRuntimeError: if a bridge node fails after all retries.
    """
    if workflow_context.workflow_run_id is None:
        raise ValueError("workflow_context.workflow_run_id is required for graph execution")
    run = session.execute(
        select(WorkflowRun)
        .where(WorkflowRun.id == workflow_context.workflow_run_id)
        .options(selectinload(WorkflowRun.node_results))
    ).scalar_one()
    # Index the pre-created per-node result rows by node name for O(1) lookup.
    node_results = {node_result.node_name: node_result for node_result in run.node_results}
    state = WorkflowGraphState()
    task_ids: list[str] = []
    node_task_ids: dict[str, str] = {}
    skipped_node_ids: list[str] = []
    for node in workflow_context.ordered_nodes:
        node_result = node_results.get(node.id)
        if node_result is None:
            # Best-effort: a node without a persisted result row is logged
            # and skipped rather than aborting the whole run.
            logger.warning(
                "[WORKFLOW] Missing WorkflowNodeResult row for node %s on run %s",
                node.id,
                run.id,
            )
            continue
        retry_policy = _retry_policy(node.params)
        failure_policy = _failure_policy(node.params)
        metadata = _base_output(node_result.output, node)
        metadata["retry_policy"] = retry_policy
        metadata["failure_policy"] = failure_policy
        definition = get_node_definition(node.step)
        bridge_executor = _BRIDGE_EXECUTORS.get(node.step)
        if bridge_executor is not None:
            # --- In-process (bridge) execution with retry ---
            max_attempts = retry_policy["max_attempts"]
            last_error: str | None = None
            for attempt in range(1, max_attempts + 1):
                started = time.perf_counter()
                attempt_output = dict(metadata)
                attempt_output["attempt_count"] = attempt
                attempt_output["max_attempts"] = max_attempts
                node_result.status = "running"
                node_result.output = attempt_output
                # Flush so observers see the "running" state before work starts.
                session.flush()
                try:
                    payload, status, log_message = bridge_executor(
                        session=session,
                        workflow_context=workflow_context,
                        state=state,
                        node_params=node.params,
                    )
                except Exception as exc:
                    # Errors are truncated to 2000 chars before persisting.
                    last_error = str(exc)[:2000]
                    if attempt < max_attempts:
                        retry_output = dict(attempt_output)
                        retry_output["last_error"] = last_error
                        retry_output["retry_state"] = "retrying"
                        node_result.status = "retrying"
                        node_result.log = f"Attempt {attempt}/{max_attempts} failed: {last_error}"
                        node_result.output = retry_output
                        node_result.duration_s = round(time.perf_counter() - started, 4)
                        session.flush()
                        continue
                    # Retries exhausted: persist the failure, then halt the run.
                    failed_output = dict(attempt_output)
                    failed_output["last_error"] = last_error
                    failed_output["retry_exhausted"] = True
                    node_result.status = "failed"
                    node_result.log = last_error
                    node_result.duration_s = round(time.perf_counter() - started, 4)
                    node_result.output = failed_output
                    session.flush()
                    raise WorkflowGraphRuntimeError(
                        f"Node '{node.id}' ({node.step.value}) failed: {exc}"
                    ) from exc
                if payload:
                    # Merge the executor's payload into this node's output and
                    # expose it to later nodes via the shared state.
                    metadata.update(payload)
                    state.node_outputs[node.id] = payload
                final_output = dict(metadata)
                final_output["attempt_count"] = attempt
                final_output["max_attempts"] = max_attempts
                if last_error is not None:
                    # A previous attempt failed but this one did not raise.
                    final_output["last_error"] = last_error
                    final_output["retry_state"] = "recovered"
                node_result.status = status
                node_result.log = log_message
                node_result.output = final_output
                node_result.duration_s = round(time.perf_counter() - started, 4)
                session.flush()
                if status == "failed":
                    # Executor returned a soft failure (no exception): apply
                    # the same retry/halt policy as the exception path.
                    last_error = (log_message or "unknown error")[:2000]
                    if attempt < max_attempts:
                        retry_output = dict(final_output)
                        retry_output["last_error"] = last_error
                        retry_output["retry_state"] = "retrying"
                        node_result.status = "retrying"
                        node_result.log = f"Attempt {attempt}/{max_attempts} failed: {last_error}"
                        node_result.output = retry_output
                        session.flush()
                        continue
                    failed_output = dict(final_output)
                    failed_output["last_error"] = last_error
                    failed_output["retry_exhausted"] = True
                    node_result.status = "failed"
                    node_result.log = last_error
                    node_result.output = failed_output
                    session.flush()
                    raise WorkflowGraphRuntimeError(
                        f"Node '{node.id}' ({node.step.value}) failed: {last_error}"
                    )
                if status == "skipped":
                    skipped_node_ids.append(node.id)
                # Attempt succeeded (completed or skipped): stop retrying.
                break
            continue
        task_name = STEP_TASK_MAP.get(node.step)
        if task_name is not None:
            # --- Celery-dispatched execution ---
            if node.step in _ORDER_LINE_RENDER_STEPS and state.setup is not None and not state.setup.is_ready:
                # Render-family steps are gated on a successful setup node.
                metadata["blocked_by"] = "order_line_setup"
                node_result.status = "skipped"
                node_result.output = metadata
                node_result.log = (
                    f"Skipped because order_line_setup did not complete successfully "
                    f"({state.setup.status})"
                )
                node_result.duration_s = None
                session.flush()
                skipped_node_ids.append(node.id)
                continue
            # Local import — presumably to avoid a circular import at module load.
            from app.tasks.celery_app import celery_app

            result = celery_app.send_task(task_name, args=[workflow_context.context_id], kwargs=node.params)
            metadata["task_id"] = result.id
            if definition is not None:
                metadata["execution_kind"] = definition.execution_kind
            metadata["attempt_count"] = 1
            metadata["max_attempts"] = retry_policy["max_attempts"]
            node_result.status = "queued"
            node_result.output = metadata
            node_result.log = None
            node_result.duration_s = None
            session.flush()
            task_ids.append(result.id)
            node_task_ids[node.id] = result.id
            logger.info(
                "[WORKFLOW] Dispatched node %r (step=%s, mode=%s, run=%s) -> Celery task %s",
                node.id,
                node.step,
                workflow_context.execution_mode,
                workflow_context.workflow_run_id,
                result.id,
            )
            continue
        # --- Neither bridge nor Celery: mark unsupported node as skipped ---
        metadata["execution_kind"] = definition.execution_kind if definition is not None else "bridge"
        node_result.status = "skipped"
        node_result.output = metadata
        node_result.log = f"Graph runtime not implemented for step '{node.step.value}'"
        node_result.duration_s = None
        session.flush()
        skipped_node_ids.append(node.id)
    # Overall run status: failed beats pending beats completed.
    run.celery_task_id = task_ids[0] if task_ids else None
    if any(node_result.status == "failed" for node_result in run.node_results):
        run.status = "failed"
        # NOTE(review): datetime.utcnow() is naive and deprecated in 3.12 —
        # confirm whether completed_at is stored tz-aware before changing.
        run.completed_at = datetime.utcnow()
    elif task_ids:
        # Async tasks are still in flight; the run completes later.
        run.status = "pending"
        run.completed_at = None
    else:
        run.status = "completed"
        run.completed_at = datetime.utcnow()
    session.flush()
    return WorkflowDispatchResult(
        context=workflow_context,
        task_ids=task_ids,
        node_task_ids=node_task_ids,
        skipped_node_ids=skipped_node_ids,
    )


def _base_output(existing: dict[str, Any] | None, node) -> dict[str, Any]:
    """Build the baseline output dict for a node, preserving any prior output.

    Adds step name, UI label and execution kind without overwriting values
    already present (``setdefault``).
    """
    metadata = dict(existing or {})
    metadata.setdefault("step", node.step.value)
    if node.ui and node.ui.label:
        metadata.setdefault("label", node.ui.label)
    definition = get_node_definition(node.step)
    if definition is not None:
        metadata.setdefault("execution_kind", definition.execution_kind)
    return metadata


def _retry_policy(node_params: dict[str, Any]) -> dict[str, Any]:
    """Extract the retry policy from node params; max_attempts clamped to 1..5."""
    raw = node_params.get("retry_policy")
    if not isinstance(raw, dict):
        raw = {}
    try:
        max_attempts = int(raw.get("max_attempts", 1))
    except (TypeError, ValueError):
        max_attempts = 1
    return {
        "max_attempts": max(1, min(max_attempts, 5)),
    }


def _failure_policy(node_params: dict[str, Any]) -> dict[str, Any]:
    """Extract the failure policy from node params (defaults: halt, no fallback)."""
    raw = node_params.get("failure_policy")
    if not isinstance(raw, dict):
        raw = {}
    return {
        "halt_workflow": bool(raw.get("halt_workflow", True)),
        "fallback_to_legacy": bool(raw.get("fallback_to_legacy", False)),
    }


def _serialize_setup_result(result: OrderLineRenderSetupResult) -> dict[str, Any]:
    """Flatten an OrderLineRenderSetupResult into a JSON-friendly dict."""
    payload: dict[str, Any] = {
        "setup_status": result.status,
        "reason": result.reason,
        "materials_source_count": len(result.materials_source or []),
        "part_colors_count": len(result.part_colors or {}),
        "usd_render_path": str(result.usd_render_path) if result.usd_render_path else None,
        "glb_reuse_path": str(result.glb_reuse_path) if result.glb_reuse_path else None,
    }
    if result.order_line is not None:
        payload["order_line_id"] = str(result.order_line.id)
        payload["product_id"] = str(result.order_line.product_id) if result.order_line.product_id else None
        payload["output_type_id"] = str(result.order_line.output_type_id) if result.order_line.output_type_id else None
    if result.order is not None:
        payload["order_id"] = str(result.order.id)
        payload["order_status"] = result.order.status.value if getattr(result.order, "status", None) else None
    if result.cad_file is not None:
        payload["cad_file_id"] = str(result.cad_file.id)
        payload["step_path"] = result.cad_file.stored_path
    return payload


def _serialize_template_result(result: TemplateResolutionResult) -> dict[str, Any]:
    """Flatten a TemplateResolutionResult into a JSON-friendly dict."""
    return {
        "template_id": str(result.template.id) if result.template is not None else None,
        "template_name": result.template.name if result.template is not None else None,
        "template_path": result.template.blend_file_path if result.template is not None else None,
        "material_library": result.material_library,
        "material_map": result.material_map,
        "material_map_count": len(result.material_map or {}),
        "use_materials": result.use_materials,
        "override_material": result.override_material,
        "category_key": result.category_key,
        "output_type_id": result.output_type_id,
    }


def _serialize_material_result(result: MaterialResolutionResult) -> dict[str, Any]:
    """Flatten a MaterialResolutionResult into a JSON-friendly dict."""
    return {
        "material_map": result.material_map,
        "material_map_count": len(result.material_map or {}),
        "use_materials": result.use_materials,
        "override_material": result.override_material,
        "source_material_count": result.source_material_count,
        "resolved_material_count": result.resolved_material_count,
    }


def _serialize_auto_populate_result(result: AutoPopulateMaterialsResult) -> dict[str, Any]:
    """Flatten an AutoPopulateMaterialsResult into a JSON-friendly dict."""
    return {
        "cad_file_id": result.cad_file_id,
        "updated_product_ids": result.updated_product_ids,
        "updated_product_count": len(result.updated_product_ids),
        "queued_thumbnail_regeneration": result.queued_thumbnail_regeneration,
        "part_colors": result.part_colors,
        "part_colors_count": len(result.part_colors or {}),
        "cad_parts": result.cad_parts,
    }


def _serialize_bbox_result(result: BBoxResolutionResult) -> dict[str, Any]:
    """Flatten a BBoxResolutionResult into a JSON-friendly dict."""
    return {
        "bbox_data": result.bbox_data,
        "has_bbox": result.has_bbox,
        "source_kind": result.source_kind,
        "step_path": result.step_path,
        "glb_path": result.glb_path,
    }


def _execute_order_line_setup(
    *,
    session: Session,
    workflow_context: WorkflowContext,
    state: WorkflowGraphState,
    node_params: dict[str, Any],
) -> tuple[dict[str, Any], str, str | None]:
    """Bridge executor: prepare the order-line render context.

    Stores the result on ``state.setup``; returns (payload, status, log).
    Status maps: ready -> completed, skip -> skipped, anything else -> failed.
    """
    del node_params
    setup = prepare_order_line_render_context(session, workflow_context.context_id)
    state.setup = setup
    payload = _serialize_setup_result(setup)
    if setup.status == "ready":
        return payload, "completed", None
    if setup.status == "skip":
        return payload, "skipped", setup.reason
    return payload, "failed", setup.reason or "order_line_setup_failed"


def _execute_resolve_template(
    *,
    session: Session,
    workflow_context: WorkflowContext,
    state: WorkflowGraphState,
    node_params: dict[str, Any],
) -> tuple[dict[str, Any], str, str | None]:
    """Bridge executor: resolve the Blender template for the order line.

    Requires a ready ``state.setup``; propagates a "skip" setup as skipped.
    """
    del workflow_context, node_params
    if state.setup is None or not state.setup.is_ready:
        if state.setup is not None and state.setup.status == "skip":
            return _serialize_setup_result(state.setup), "skipped", state.setup.reason
        raise WorkflowGraphRuntimeError("resolve_template requires a ready order_line_setup result")
    result = resolve_order_line_template_context(session, state.setup)
    state.template = result
    return _serialize_template_result(result), "completed", None


def _execute_material_map_resolve(
    *,
    session: Session,
    workflow_context: WorkflowContext,
    state: WorkflowGraphState,
    node_params: dict[str, Any],
) -> tuple[dict[str, Any], str, str | None]:
    """Bridge executor: resolve the material map for the order line.

    Requires a ready ``state.setup``; uses ``state.template`` when available.
    """
    del session, workflow_context, node_params
    if state.setup is None or not state.setup.is_ready:
        if state.setup is not None and state.setup.status == "skip":
            return _serialize_setup_result(state.setup), "skipped", state.setup.reason
        raise WorkflowGraphRuntimeError("material_map_resolve requires a ready order_line_setup result")
    line = state.setup.order_line
    cad_file = state.setup.cad_file
    if line is None:
        raise WorkflowGraphRuntimeError("material_map_resolve requires an order line")
    material_library = state.template.material_library if state.template is not None else None
    template = state.template.template if state.template is not None else None
    result = resolve_order_line_material_map(
        line,
        cad_file,
        state.setup.materials_source,
        material_library=material_library,
        template=template,
    )
    state.materials = result
    return _serialize_material_result(result), "completed", None


def _execute_auto_populate_materials(
    *,
    session: Session,
    workflow_context: WorkflowContext,
    state: WorkflowGraphState,
    node_params: dict[str, Any],
) -> tuple[dict[str, Any], str, str | None]:
    """Bridge executor: auto-populate product materials from the CAD file.

    Requires ``state.setup.cad_file``; refreshes the product afterwards so
    ``state.setup.materials_source`` reflects the newly populated materials.
    """
    del workflow_context, node_params
    if state.setup is None or state.setup.cad_file is None:
        if state.setup is not None and state.setup.status == "skip":
            return _serialize_setup_result(state.setup), "skipped", state.setup.reason
        raise WorkflowGraphRuntimeError("auto_populate_materials requires a resolved cad_file")
    result = auto_populate_materials_for_cad(session, str(state.setup.cad_file.id))
    state.auto_populate = result
    if state.setup.order_line is not None and state.setup.order_line.product is not None:
        # Re-read the product so downstream nodes see the updated materials.
        session.refresh(state.setup.order_line.product)
        state.setup.materials_source = state.setup.order_line.product.cad_part_materials or []
    return _serialize_auto_populate_result(result), "completed", None


def _execute_glb_bbox(
    *,
    session: Session,
    workflow_context: WorkflowContext,
    state: WorkflowGraphState,
    node_params: dict[str, Any],
) -> tuple[dict[str, Any], str, str | None]:
    """Bridge executor: resolve the bounding box for the CAD geometry.

    GLB path resolution order: explicit ``node_params["glb_path"]``, then the
    setup's reuse path, then a ``<stem>_thumbnail.glb`` sibling of the STEP
    file if it exists on disk.
    """
    del session, workflow_context
    if state.setup is None or state.setup.cad_file is None:
        if state.setup is not None and state.setup.status == "skip":
            return _serialize_setup_result(state.setup), "skipped", state.setup.reason
        raise WorkflowGraphRuntimeError("glb_bbox requires a resolved cad_file")
    step_path = state.setup.cad_file.stored_path
    glb_path = node_params.get("glb_path")
    if glb_path is None and state.setup.glb_reuse_path is not None:
        glb_path = str(state.setup.glb_reuse_path)
    elif glb_path is None:
        step_file = Path(step_path)
        fallback_glb = step_file.parent / f"{step_file.stem}_thumbnail.glb"
        if fallback_glb.exists():
            glb_path = str(fallback_glb)
    result = resolve_cad_bbox(step_path, glb_path=glb_path)
    state.bbox = result
    return _serialize_bbox_result(result), "completed", None


# Step -> in-process executor mapping; steps absent here fall through to the
# Celery dispatch path (or are skipped if no task mapping exists either).
_BRIDGE_EXECUTORS = {
    StepName.ORDER_LINE_SETUP: _execute_order_line_setup,
    StepName.RESOLVE_TEMPLATE: _execute_resolve_template,
    StepName.MATERIAL_MAP_RESOLVE: _execute_material_map_resolve,
    StepName.AUTO_POPULATE_MATERIALS: _execute_auto_populate_materials,
    StepName.GLB_BBOX: _execute_glb_bbox,
}