# HartOMat/backend/app/domains/rendering/workflow_graph_runtime.py
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from app.core.process_steps import StepName
from app.domains.rendering.models import WorkflowNodeResult, WorkflowRun
from app.domains.rendering.workflow_executor import STEP_TASK_MAP, WorkflowContext, WorkflowDispatchResult
from app.domains.rendering.workflow_node_registry import get_node_definition
from app.domains.rendering.workflow_runtime_services import (
AutoPopulateMaterialsResult,
BBoxResolutionResult,
MaterialResolutionResult,
OrderLineRenderSetupResult,
TemplateResolutionResult,
auto_populate_materials_for_cad,
prepare_order_line_render_context,
resolve_cad_bbox,
resolve_order_line_material_map,
resolve_order_line_template_context,
)
logger = logging.getLogger(__name__)
class WorkflowGraphRuntimeError(RuntimeError):
    """Raised when a bridge-executed graph node fails (after exhausting its
    retry budget) or when a node's precondition — e.g. a ready
    order_line_setup result in :class:`WorkflowGraphState` — is missing."""
    pass
@dataclass(slots=True)
class WorkflowGraphState:
    """In-process state shared by the bridge executors of one graph run.

    Each bridge step stores its result here so downstream bridge steps in the
    same `execute_graph_workflow` call can read it.
    """

    # Result of prepare_order_line_render_context; gates all later steps.
    setup: OrderLineRenderSetupResult | None = None
    # Result of resolve_order_line_template_context (resolve_template step).
    template: TemplateResolutionResult | None = None
    # Result of resolve_order_line_material_map (material_map_resolve step).
    materials: MaterialResolutionResult | None = None
    # Result of auto_populate_materials_for_cad (auto_populate_materials step).
    auto_populate: AutoPopulateMaterialsResult | None = None
    # Result of resolve_cad_bbox (glb_bbox step).
    bbox: BBoxResolutionResult | None = None
    # Serialized payload of each successfully executed bridge node, keyed by node id.
    node_outputs: dict[str, dict[str, Any]] = field(default_factory=dict)
# Celery-dispatched steps that depend on a successful order_line_setup;
# execute_graph_workflow skips them when state.setup is present but not ready.
_ORDER_LINE_RENDER_STEPS = {
    StepName.BLENDER_STILL,
    StepName.BLENDER_TURNTABLE,
    StepName.EXPORT_BLEND,
    StepName.OUTPUT_SAVE,
    StepName.NOTIFY,
}
def find_unsupported_graph_nodes(workflow_context: WorkflowContext) -> list[str]:
    """Return the ids of graph nodes that this runtime cannot execute.

    A node is supported when its step either has an in-process bridge executor
    or maps to a Celery task name; every other node id is reported.
    """
    return [
        node.id
        for node in workflow_context.ordered_nodes
        if node.step not in _BRIDGE_EXECUTORS and STEP_TASK_MAP.get(node.step) is None
    ]
def execute_graph_workflow(
    session: Session,
    workflow_context: WorkflowContext,
) -> WorkflowDispatchResult:
    """Execute a workflow graph: run bridge steps in-process, dispatch the rest to Celery.

    Walks ``workflow_context.ordered_nodes`` in order. For each node:
      * bridge steps (``_BRIDGE_EXECUTORS``) run synchronously with a bounded
        retry loop, persisting status/output/log on the node's
        ``WorkflowNodeResult`` row after every attempt;
      * steps with a Celery task (``STEP_TASK_MAP``) are dispatched and marked
        "queued" — unless they depend on an order-line setup that did not
        become ready, in which case they are marked "skipped";
      * anything else is marked "skipped" as not implemented.

    Finally the ``WorkflowRun`` row's status is rolled up (failed / pending /
    completed) and a :class:`WorkflowDispatchResult` is returned.

    Raises:
        ValueError: if ``workflow_context.workflow_run_id`` is None.
        WorkflowGraphRuntimeError: if a bridge node fails after exhausting
            its retry budget (the failure is persisted before raising).
    """
    if workflow_context.workflow_run_id is None:
        raise ValueError("workflow_context.workflow_run_id is required for graph execution")
    # Load the run with its node-result rows eagerly so per-node updates below
    # don't trigger extra lazy loads.
    run = session.execute(
        select(WorkflowRun)
        .where(WorkflowRun.id == workflow_context.workflow_run_id)
        .options(selectinload(WorkflowRun.node_results))
    ).scalar_one()
    node_results = {node_result.node_name: node_result for node_result in run.node_results}
    state = WorkflowGraphState()
    task_ids: list[str] = []
    node_task_ids: dict[str, str] = {}
    skipped_node_ids: list[str] = []
    for node in workflow_context.ordered_nodes:
        node_result = node_results.get(node.id)
        if node_result is None:
            # A node without a pre-created result row is skipped with a warning
            # rather than failing the whole run.
            logger.warning(
                "[WORKFLOW] Missing WorkflowNodeResult row for node %s on run %s",
                node.id,
                run.id,
            )
            continue
        retry_policy = _retry_policy(node.params)
        failure_policy = _failure_policy(node.params)
        # Base output metadata carried through every state transition of this node.
        metadata = _base_output(node_result.output, node)
        metadata["retry_policy"] = retry_policy
        metadata["failure_policy"] = failure_policy
        definition = get_node_definition(node.step)
        bridge_executor = _BRIDGE_EXECUTORS.get(node.step)
        if bridge_executor is not None:
            # --- In-process ("bridge") execution with bounded retries. ---
            max_attempts = retry_policy["max_attempts"]
            last_error: str | None = None
            for attempt in range(1, max_attempts + 1):
                started = time.perf_counter()
                attempt_output = dict(metadata)
                attempt_output["attempt_count"] = attempt
                attempt_output["max_attempts"] = max_attempts
                # Persist "running" before executing so observers see progress.
                node_result.status = "running"
                node_result.output = attempt_output
                session.flush()
                try:
                    payload, status, log_message = bridge_executor(
                        session=session,
                        workflow_context=workflow_context,
                        state=state,
                        node_params=node.params,
                    )
                except Exception as exc:
                    # Error text is truncated to keep the DB column bounded.
                    last_error = str(exc)[:2000]
                    if attempt < max_attempts:
                        retry_output = dict(attempt_output)
                        retry_output["last_error"] = last_error
                        retry_output["retry_state"] = "retrying"
                        node_result.status = "retrying"
                        node_result.log = f"Attempt {attempt}/{max_attempts} failed: {last_error}"
                        node_result.output = retry_output
                        node_result.duration_s = round(time.perf_counter() - started, 4)
                        session.flush()
                        continue
                    # Retries exhausted: persist terminal failure, then raise.
                    failed_output = dict(attempt_output)
                    failed_output["last_error"] = last_error
                    failed_output["retry_exhausted"] = True
                    node_result.status = "failed"
                    node_result.log = last_error
                    node_result.duration_s = round(time.perf_counter() - started, 4)
                    node_result.output = failed_output
                    session.flush()
                    raise WorkflowGraphRuntimeError(
                        f"Node '{node.id}' ({node.step.value}) failed: {exc}"
                    ) from exc
                if payload:
                    # Merge the executor payload into this node's metadata and
                    # expose it to later bridge steps via the shared state.
                    metadata.update(payload)
                    state.node_outputs[node.id] = payload
                final_output = dict(metadata)
                final_output["attempt_count"] = attempt
                final_output["max_attempts"] = max_attempts
                if last_error is not None:
                    # An earlier attempt raised but this one returned: record recovery.
                    final_output["last_error"] = last_error
                    final_output["retry_state"] = "recovered"
                node_result.status = status
                node_result.log = log_message
                node_result.output = final_output
                node_result.duration_s = round(time.perf_counter() - started, 4)
                session.flush()
                if status == "failed":
                    # Executor reported failure without raising — same retry path.
                    last_error = (log_message or "unknown error")[:2000]
                    if attempt < max_attempts:
                        retry_output = dict(final_output)
                        retry_output["last_error"] = last_error
                        retry_output["retry_state"] = "retrying"
                        node_result.status = "retrying"
                        node_result.log = f"Attempt {attempt}/{max_attempts} failed: {last_error}"
                        node_result.output = retry_output
                        session.flush()
                        continue
                    failed_output = dict(final_output)
                    failed_output["last_error"] = last_error
                    failed_output["retry_exhausted"] = True
                    node_result.status = "failed"
                    node_result.log = last_error
                    node_result.output = failed_output
                    session.flush()
                    raise WorkflowGraphRuntimeError(
                        f"Node '{node.id}' ({node.step.value}) failed: {last_error}"
                    )
                if status == "skipped":
                    skipped_node_ids.append(node.id)
                # Completed or skipped: leave the retry loop.
                break
            continue
        task_name = STEP_TASK_MAP.get(node.step)
        if task_name is not None:
            # --- Asynchronous execution via Celery. ---
            if node.step in _ORDER_LINE_RENDER_STEPS and state.setup is not None and not state.setup.is_ready:
                # Render-related steps are pointless without a ready order-line setup.
                metadata["blocked_by"] = "order_line_setup"
                node_result.status = "skipped"
                node_result.output = metadata
                node_result.log = (
                    f"Skipped because order_line_setup did not complete successfully "
                    f"({state.setup.status})"
                )
                node_result.duration_s = None
                session.flush()
                skipped_node_ids.append(node.id)
                continue
            # Local import avoids a circular dependency at module load time.
            from app.tasks.celery_app import celery_app
            result = celery_app.send_task(task_name, args=[workflow_context.context_id], kwargs=node.params)
            metadata["task_id"] = result.id
            if definition is not None:
                metadata["execution_kind"] = definition.execution_kind
            metadata["attempt_count"] = 1
            metadata["max_attempts"] = retry_policy["max_attempts"]
            node_result.status = "queued"
            node_result.output = metadata
            node_result.log = None
            node_result.duration_s = None
            session.flush()
            task_ids.append(result.id)
            node_task_ids[node.id] = result.id
            logger.info(
                "[WORKFLOW] Dispatched node %r (step=%s, mode=%s, run=%s) -> Celery task %s",
                node.id,
                node.step,
                workflow_context.execution_mode,
                workflow_context.workflow_run_id,
                result.id,
            )
            continue
        # --- No bridge executor and no Celery task: not implemented yet. ---
        metadata["execution_kind"] = definition.execution_kind if definition is not None else "bridge"
        node_result.status = "skipped"
        node_result.output = metadata
        node_result.log = f"Graph runtime not implemented for step '{node.step.value}'"
        node_result.duration_s = None
        session.flush()
        skipped_node_ids.append(node.id)
    # Roll up run status: failed beats pending beats completed.
    run.celery_task_id = task_ids[0] if task_ids else None
    if any(node_result.status == "failed" for node_result in run.node_results):
        run.status = "failed"
        run.completed_at = datetime.utcnow()
    elif task_ids:
        # Celery tasks are still in flight; the run stays pending.
        run.status = "pending"
        run.completed_at = None
    else:
        run.status = "completed"
        run.completed_at = datetime.utcnow()
    session.flush()
    return WorkflowDispatchResult(
        context=workflow_context,
        task_ids=task_ids,
        node_task_ids=node_task_ids,
        skipped_node_ids=skipped_node_ids,
    )
def _base_output(existing: dict[str, Any] | None, node) -> dict[str, Any]:
    """Build the baseline output metadata for a node, preserving existing keys.

    Copies ``existing`` (if any) and fills in, without overwriting, the step
    name, the UI label, and the node definition's execution kind.
    """
    output = {} if existing is None else dict(existing)
    output.setdefault("step", node.step.value)
    ui = node.ui
    if ui and ui.label:
        output.setdefault("label", ui.label)
    definition = get_node_definition(node.step)
    if definition is not None:
        output.setdefault("execution_kind", definition.execution_kind)
    return output
def _retry_policy(node_params: dict[str, Any]) -> dict[str, Any]:
raw = node_params.get("retry_policy")
if not isinstance(raw, dict):
raw = {}
try:
max_attempts = int(raw.get("max_attempts", 1))
except (TypeError, ValueError):
max_attempts = 1
return {
"max_attempts": max(1, min(max_attempts, 5)),
}
def _failure_policy(node_params: dict[str, Any]) -> dict[str, Any]:
raw = node_params.get("failure_policy")
if not isinstance(raw, dict):
raw = {}
return {
"halt_workflow": bool(raw.get("halt_workflow", True)),
"fallback_to_legacy": bool(raw.get("fallback_to_legacy", False)),
}
def _serialize_setup_result(result: OrderLineRenderSetupResult) -> dict[str, Any]:
payload: dict[str, Any] = {
"setup_status": result.status,
"reason": result.reason,
"materials_source_count": len(result.materials_source or []),
"part_colors_count": len(result.part_colors or {}),
"usd_render_path": str(result.usd_render_path) if result.usd_render_path else None,
"glb_reuse_path": str(result.glb_reuse_path) if result.glb_reuse_path else None,
}
if result.order_line is not None:
payload["order_line_id"] = str(result.order_line.id)
payload["product_id"] = str(result.order_line.product_id) if result.order_line.product_id else None
payload["output_type_id"] = str(result.order_line.output_type_id) if result.order_line.output_type_id else None
if result.order is not None:
payload["order_id"] = str(result.order.id)
payload["order_status"] = result.order.status.value if getattr(result.order, "status", None) else None
if result.cad_file is not None:
payload["cad_file_id"] = str(result.cad_file.id)
payload["step_path"] = result.cad_file.stored_path
return payload
def _serialize_template_result(result: TemplateResolutionResult) -> dict[str, Any]:
return {
"template_id": str(result.template.id) if result.template is not None else None,
"template_name": result.template.name if result.template is not None else None,
"template_path": result.template.blend_file_path if result.template is not None else None,
"material_library": result.material_library,
"material_map": result.material_map,
"material_map_count": len(result.material_map or {}),
"use_materials": result.use_materials,
"override_material": result.override_material,
"category_key": result.category_key,
"output_type_id": result.output_type_id,
}
def _serialize_material_result(result: MaterialResolutionResult) -> dict[str, Any]:
return {
"material_map": result.material_map,
"material_map_count": len(result.material_map or {}),
"use_materials": result.use_materials,
"override_material": result.override_material,
"source_material_count": result.source_material_count,
"resolved_material_count": result.resolved_material_count,
}
def _serialize_auto_populate_result(result: AutoPopulateMaterialsResult) -> dict[str, Any]:
return {
"cad_file_id": result.cad_file_id,
"updated_product_ids": result.updated_product_ids,
"updated_product_count": len(result.updated_product_ids),
"queued_thumbnail_regeneration": result.queued_thumbnail_regeneration,
"part_colors": result.part_colors,
"part_colors_count": len(result.part_colors or {}),
"cad_parts": result.cad_parts,
}
def _serialize_bbox_result(result: BBoxResolutionResult) -> dict[str, Any]:
return {
"bbox_data": result.bbox_data,
"has_bbox": result.has_bbox,
"source_kind": result.source_kind,
"step_path": result.step_path,
"glb_path": result.glb_path,
}
def _execute_order_line_setup(
    *,
    session: Session,
    workflow_context: WorkflowContext,
    state: WorkflowGraphState,
    node_params: dict[str, Any],
) -> tuple[dict[str, Any], str, str | None]:
    """Bridge executor: prepare the order-line render context.

    Stores the result on ``state.setup`` and maps the setup status to a node
    status: "ready" -> completed, "skip" -> skipped, anything else -> failed.
    """
    del node_params  # unused by this step
    setup = prepare_order_line_render_context(session, workflow_context.context_id)
    state.setup = setup
    payload = _serialize_setup_result(setup)
    outcome_by_status = {
        "ready": ("completed", None),
        "skip": ("skipped", setup.reason),
    }
    fallback = ("failed", setup.reason or "order_line_setup_failed")
    node_status, message = outcome_by_status.get(setup.status, fallback)
    return payload, node_status, message
def _execute_resolve_template(
    *,
    session: Session,
    workflow_context: WorkflowContext,
    state: WorkflowGraphState,
    node_params: dict[str, Any],
) -> tuple[dict[str, Any], str, str | None]:
    """Bridge executor: resolve the render template for the order line.

    Requires a ready ``state.setup``; a "skip" setup propagates as skipped,
    any other non-ready setup raises WorkflowGraphRuntimeError.
    """
    del workflow_context, node_params
    setup = state.setup
    if setup is None or not setup.is_ready:
        if setup is not None and setup.status == "skip":
            return _serialize_setup_result(setup), "skipped", setup.reason
        raise WorkflowGraphRuntimeError("resolve_template requires a ready order_line_setup result")
    template_result = resolve_order_line_template_context(session, setup)
    state.template = template_result
    return _serialize_template_result(template_result), "completed", None
def _execute_material_map_resolve(
    *,
    session: Session,
    workflow_context: WorkflowContext,
    state: WorkflowGraphState,
    node_params: dict[str, Any],
) -> tuple[dict[str, Any], str, str | None]:
    """Bridge executor: resolve the part -> material mapping for the order line.

    Requires a ready ``state.setup`` with an order line; uses the template
    context from ``state.template`` when a resolve_template node ran earlier.
    """
    del session, workflow_context, node_params
    setup = state.setup
    if setup is None or not setup.is_ready:
        if setup is not None and setup.status == "skip":
            return _serialize_setup_result(setup), "skipped", setup.reason
        raise WorkflowGraphRuntimeError("material_map_resolve requires a ready order_line_setup result")
    order_line = setup.order_line
    if order_line is None:
        raise WorkflowGraphRuntimeError("material_map_resolve requires an order line")
    template_ctx = state.template
    resolution = resolve_order_line_material_map(
        order_line,
        setup.cad_file,
        setup.materials_source,
        material_library=template_ctx.material_library if template_ctx is not None else None,
        template=template_ctx.template if template_ctx is not None else None,
    )
    state.materials = resolution
    return _serialize_material_result(resolution), "completed", None
def _execute_auto_populate_materials(
    *,
    session: Session,
    workflow_context: WorkflowContext,
    state: WorkflowGraphState,
    node_params: dict[str, Any],
) -> tuple[dict[str, Any], str, str | None]:
    """Bridge executor: auto-populate product materials from the CAD file.

    Requires a resolved ``state.setup.cad_file``. After populating, refreshes
    the order line's product and updates ``setup.materials_source`` so later
    steps see the new materials.
    """
    del workflow_context, node_params
    setup = state.setup
    if setup is None or setup.cad_file is None:
        if setup is not None and setup.status == "skip":
            return _serialize_setup_result(setup), "skipped", setup.reason
        raise WorkflowGraphRuntimeError("auto_populate_materials requires a resolved cad_file")
    populate_result = auto_populate_materials_for_cad(session, str(setup.cad_file.id))
    state.auto_populate = populate_result
    order_line = setup.order_line
    product = order_line.product if order_line is not None else None
    if product is not None:
        session.refresh(product)
        setup.materials_source = product.cad_part_materials or []
    return _serialize_auto_populate_result(populate_result), "completed", None
def _execute_glb_bbox(
    *,
    session: Session,
    workflow_context: WorkflowContext,
    state: WorkflowGraphState,
    node_params: dict[str, Any],
) -> tuple[dict[str, Any], str, str | None]:
    """Bridge executor: compute the bounding box for the order line's CAD data.

    GLB path precedence: explicit ``node_params["glb_path"]``, then the setup's
    reusable GLB, then a ``<stem>_thumbnail.glb`` sibling of the STEP file if
    it exists on disk; otherwise the bbox is resolved from the STEP file alone.
    """
    del session, workflow_context
    setup = state.setup
    if setup is None or setup.cad_file is None:
        if setup is not None and setup.status == "skip":
            return _serialize_setup_result(setup), "skipped", setup.reason
        raise WorkflowGraphRuntimeError("glb_bbox requires a resolved cad_file")
    step_path = setup.cad_file.stored_path
    glb_path = node_params.get("glb_path")
    if glb_path is None:
        if setup.glb_reuse_path is not None:
            glb_path = str(setup.glb_reuse_path)
        else:
            step_file = Path(step_path)
            candidate = step_file.parent / f"{step_file.stem}_thumbnail.glb"
            if candidate.exists():
                glb_path = str(candidate)
    bbox_result = resolve_cad_bbox(step_path, glb_path=glb_path)
    state.bbox = bbox_result
    return _serialize_bbox_result(bbox_result), "completed", None
# Steps executed synchronously in-process ("bridge" executors) rather than
# dispatched to Celery. Consulted by execute_graph_workflow (to run them) and
# find_unsupported_graph_nodes (to count them as supported).
_BRIDGE_EXECUTORS = {
    StepName.ORDER_LINE_SETUP: _execute_order_line_setup,
    StepName.RESOLVE_TEMPLATE: _execute_resolve_template,
    StepName.MATERIAL_MAP_RESOLVE: _execute_material_map_resolve,
    StepName.AUTO_POPULATE_MATERIALS: _execute_auto_populate_materials,
    StepName.GLB_BBOX: _execute_glb_bbox,
}