# HartOMat/backend/app/domains/rendering/workflow_schema.py (435 lines, 17 KiB, Python)
"""Pydantic schema for validated WorkflowDefinition.config JSONB.
A workflow config is a versioned DAG description stored as JSONB. Before
being dispatched (or saved), the raw dict must pass this schema.
Example config::
{
"version": 1,
"nodes": [
{"id": "n1", "step": "resolve_step_path", "params": {}},
{"id": "n2", "step": "blender_still", "params": {"engine": "cycles"}}
],
"edges": [
{"from": "n1", "to": "n2"}
]
}
"""
from collections import deque
from typing import Any, Literal
from uuid import UUID
from pydantic import BaseModel, Field, field_validator, model_validator
from app.core.process_steps import StepName
from app.domains.rendering.workflow_node_registry import (
WorkflowNodeDefinition,
WorkflowNodeFieldDefinition,
get_node_definition,
)
# Param keys accepted on every node in addition to its registry-declared fields.
_WORKFLOW_META_PARAM_KEYS = {"retry_policy", "failure_policy"}

# Prefix of the dynamic per-template input params allowed on resolve-template nodes.
_TEMPLATE_INPUT_PARAM_PREFIX = "template_input__"

# Accepted lengths of a hex color string, counting the leading '#'.
_HEX_COLOR_LENGTHS = {7, 9}

# Characters permitted in a "safe_filename_suffix" text param.
_SAFE_FILENAME_SUFFIX_CHARS = set(
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "0123456789"
    "._-"
)
def _context_seed_artifacts(definition: WorkflowNodeDefinition) -> set[str]:
    """Return the artifacts a node gets for free from its family.

    ``order_line`` nodes are seeded with ``order_line_record`` and ``cad_file``
    nodes with ``cad_file_record``; any other family yields an empty set.
    A fresh set is returned on every call.
    """
    seed_by_family = {
        "order_line": "order_line_record",
        "cad_file": "cad_file_record",
    }
    seed = seed_by_family.get(definition.family)
    return set() if seed is None else {seed}
def _infer_concrete_workflow_family(
    definitions: list[WorkflowNodeDefinition],
) -> Literal["cad_file", "order_line", "mixed"] | None:
    """Derive the workflow family from the families of its node definitions.

    Only the ``cad_file`` and ``order_line`` families are considered; every
    other family is ignored.

    Returns:
        ``None`` when no definition belongs to either family, ``"mixed"`` when
        both occur, otherwise the single family that was found.
    """
    found: set[str] = set()
    for definition in definitions:
        if definition.family in ("cad_file", "order_line"):
            found.add(definition.family)

    if len(found) > 1:
        return "mixed"
    if not found:
        return None
    (only_family,) = found
    return only_family
def _coerce_node_label(node: "WorkflowNode") -> str:
    """Format a node for error messages as ``'<id>' (<step value>)``."""
    node_id = node.id
    step_value = node.step.value
    return "{!r} ({})".format(node_id, step_value)
def _require_node_definition(node: "WorkflowNode") -> WorkflowNodeDefinition:
    """Look up the registry definition for ``node.step``.

    Raises:
        ValueError: if ``get_node_definition`` returns ``None`` for the step.
    """
    registered = get_node_definition(node.step)
    if registered is not None:
        return registered
    label = _coerce_node_label(node)
    raise ValueError(f"node {label} is not registered in workflow_node_registry")
def _is_dynamic_template_input_param(node: "WorkflowNode", key: str) -> bool:
    """Tell whether ``key`` is a dynamic template-input param of ``node``.

    True only when the node's step is ``StepName.RESOLVE_TEMPLATE`` and ``key``
    is a string that starts with the template-input prefix followed by a
    non-blank name.
    """
    if node.step != StepName.RESOLVE_TEMPLATE:
        return False
    if not isinstance(key, str):
        return False
    if not key.startswith(_TEMPLATE_INPUT_PARAM_PREFIX):
        return False
    input_name = key[len(_TEMPLATE_INPUT_PARAM_PREFIX):]
    return input_name.strip() != ""
def _validate_param_value(
    *,
    node: "WorkflowNode",
    field_definition: WorkflowNodeFieldDefinition,
    value: Any,
) -> None:
    """Validate one node param value against its registry field definition.

    ``None`` is always accepted. Otherwise the value is checked according to
    ``field_definition.type``: ``"number"`` (with optional min/max bounds),
    ``"boolean"``, ``"select"`` (must equal one of the option values) or
    ``"text"`` (blank/length rules plus a ``text_format``-specific check).
    Field types other than these four are not validated.

    Args:
        node: The node owning the param; used only to build error messages.
        field_definition: Registry description of the param.
        value: The raw param value taken from the workflow config.

    Raises:
        ValueError: if the value violates the field definition, or a text
            field declares a ``text_format`` this function does not know.
    """
    if value is None:
        return

    field_label = f"node {_coerce_node_label(node)} param {field_definition.key!r}"
    field_type = field_definition.type

    if field_type == "number":
        # bool is a subclass of int, so it has to be rejected explicitly.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValueError(f"{field_label} must be a number")
        # Compare the raw value instead of float(value): converting a very
        # large int raises OverflowError, which is not a ValueError and would
        # escape as an unhandled error instead of a validation failure.
        if field_definition.min is not None and value < field_definition.min:
            raise ValueError(f"{field_label} must be >= {field_definition.min:g}")
        if field_definition.max is not None and value > field_definition.max:
            raise ValueError(f"{field_label} must be <= {field_definition.max:g}")
        return

    if field_type == "boolean":
        if not isinstance(value, bool):
            raise ValueError(f"{field_label} must be a boolean")
        return

    if field_type == "select":
        valid_values = {option.value for option in field_definition.options}
        try:
            is_allowed = value in valid_values
        except TypeError:
            # Unhashable values (JSON arrays/objects) cannot be looked up in a
            # set and can never equal a declared option.
            is_allowed = False
        if not is_allowed:
            allowed_values = ", ".join(repr(option) for option in sorted(valid_values, key=repr))
            raise ValueError(f"{field_label} must be one of: {allowed_values}")
        return

    if field_type != "text":
        return

    if not isinstance(value, str):
        raise ValueError(f"{field_label} must be a string")
    stripped_value = value.strip()
    if stripped_value == "":
        if field_definition.allow_blank:
            return
        raise ValueError(f"{field_label} may not be blank")
    # The length limit applies to the raw value, surrounding whitespace included.
    if field_definition.max_length is not None and len(value) > field_definition.max_length:
        raise ValueError(
            f"{field_label} must be at most {field_definition.max_length} characters"
        )

    text_format = field_definition.text_format
    if text_format == "plain":
        return
    if text_format == "uuid":
        try:
            UUID(stripped_value)
        except ValueError as exc:
            raise ValueError(f"{field_label} must be a valid UUID") from exc
        return
    if text_format == "absolute_path":
        if not stripped_value.startswith("/"):
            raise ValueError(f"{field_label} must be an absolute path")
        return
    if text_format == "absolute_blend_path":
        if not stripped_value.startswith("/"):
            raise ValueError(f"{field_label} must be an absolute path")
        if not stripped_value.lower().endswith(".blend"):
            raise ValueError(f"{field_label} must point to a .blend file")
        return
    if text_format == "absolute_glb_path":
        if not stripped_value.startswith("/"):
            raise ValueError(f"{field_label} must be an absolute path")
        if not stripped_value.lower().endswith(".glb"):
            raise ValueError(f"{field_label} must point to a .glb file")
        return
    if text_format == "float_string":
        try:
            float(stripped_value)
        except ValueError as exc:
            raise ValueError(f"{field_label} must be a valid numeric string") from exc
        return
    if text_format == "hex_color":
        if len(stripped_value) not in _HEX_COLOR_LENGTHS or not stripped_value.startswith("#"):
            raise ValueError(f"{field_label} must be a hex color like #FFFFFF or #FFFFFFFF")
        color_digits = stripped_value[1:]
        if any(character not in "0123456789abcdefABCDEF" for character in color_digits):
            raise ValueError(f"{field_label} must be a hex color like #FFFFFF or #FFFFFFFF")
        return
    if text_format == "safe_filename_suffix":
        if any(character not in _SAFE_FILENAME_SUFFIX_CHARS for character in stripped_value):
            raise ValueError(
                f"{field_label} may only contain letters, numbers, '.', '-' or '_'"
            )
        return
    raise ValueError(
        f"{field_label} uses unsupported text format {field_definition.text_format!r}"
    )
def _validate_meta_param_value(*, node: "WorkflowNode", key: str, value: Any) -> None:
    """Validate a workflow-level meta param (``retry_policy`` / ``failure_policy``).

    ``retry_policy`` must be an object whose only allowed key is
    ``max_attempts``, an integer from 1 to 5 (1 is assumed when absent).
    ``failure_policy`` must be an object whose only allowed keys are
    ``halt_workflow`` and ``fallback_to_legacy``, each a boolean when present.

    Args:
        node: The node owning the param; used only to build error messages.
        key: The meta param name.
        value: The raw value taken from the workflow config.

    Raises:
        ValueError: if ``key`` is not a supported meta param or ``value`` does
            not satisfy the rules above.
    """
    field_label = f"node {_coerce_node_label(node)} meta param {key!r}"

    if key == "retry_policy":
        if not isinstance(value, dict):
            raise ValueError(f"{field_label} must be an object")
        unknown_keys = sorted(raw_key for raw_key in value if raw_key not in {"max_attempts"})
        if unknown_keys:
            joined = ", ".join(repr(raw_key) for raw_key in unknown_keys)
            raise ValueError(f"{field_label} uses unknown key(s): {joined}")
        max_attempts = value.get("max_attempts", 1)
        # bool is a subclass of int, so it has to be rejected explicitly.
        if isinstance(max_attempts, bool) or not isinstance(max_attempts, int):
            raise ValueError(f"{field_label} field 'max_attempts' must be an integer")
        if max_attempts < 1 or max_attempts > 5:
            raise ValueError(f"{field_label} field 'max_attempts' must be between 1 and 5")
        return

    if key == "failure_policy":
        if not isinstance(value, dict):
            raise ValueError(f"{field_label} must be an object")
        # A tuple rather than a set: when several fields are invalid, the one
        # reported must not depend on set iteration order.
        boolean_keys = ("halt_workflow", "fallback_to_legacy")
        unknown_keys = sorted(raw_key for raw_key in value if raw_key not in boolean_keys)
        if unknown_keys:
            joined = ", ".join(repr(raw_key) for raw_key in unknown_keys)
            raise ValueError(f"{field_label} uses unknown key(s): {joined}")
        for bool_key in boolean_keys:
            if bool_key in value and not isinstance(value[bool_key], bool):
                raise ValueError(f"{field_label} field {bool_key!r} must be a boolean")
        return

    raise ValueError(f"{field_label} is not supported")
class WorkflowPosition(BaseModel):
    """An ``x``/``y`` coordinate pair, used as the position in a node's UI metadata."""

    x: float
    y: float
class WorkflowNodeUI(BaseModel):
    """Optional UI metadata attached to a single node; every field may be omitted."""

    type: str | None = None
    position: WorkflowPosition | None = None
    label: str | None = None
class WorkflowNode(BaseModel):
    """One node of the workflow graph: an id, the step it runs and that step's params."""

    id: str
    step: StepName  # validated against the StepName StrEnum
    # Raw param mapping; defaults to a fresh empty dict per node.
    params: dict[str, Any] = Field(default_factory=dict)
    ui: WorkflowNodeUI | None = None
class WorkflowEdge(BaseModel):
    """A directed edge between two node ids, written as ``{"from": ..., "to": ...}``."""

    # "from" is a Python keyword, so both endpoints are exposed through aliases.
    from_node: str = Field(alias="from")
    to_node: str = Field(alias="to")

    # Accept the Python field names (from_node / to_node) as well as the aliases.
    model_config = {"populate_by_name": True}
class WorkflowUI(BaseModel):
    """Optional workflow-level UI metadata: preset, execution mode and declared family."""

    preset: str | None = None
    execution_mode: Literal["legacy", "graph", "shadow"] | None = None
    family: Literal["cad_file", "order_line", "mixed"] | None = None
class WorkflowConfig(BaseModel):
    """Validated workflow config: a list of nodes, the edges between them and UI metadata.

    Field validators require at least one node and edges that only reference
    known node ids. The ``after`` model validators below then check unique
    node ids, params against the node registry, the declared UI family,
    artifact contracts (graph/shadow execution only) and finally that edges
    are unique and the graph is acyclic.
    """

    version: int = 1
    nodes: list[WorkflowNode]
    edges: list[WorkflowEdge] = Field(default_factory=list)
    ui: WorkflowUI | None = None

    @field_validator("nodes")
    @classmethod
    def nodes_not_empty(cls, v: list[WorkflowNode]) -> list[WorkflowNode]:
        """Reject a workflow that declares no nodes."""
        if not v:
            raise ValueError("workflow must have at least one node")
        return v

    @field_validator("edges")
    @classmethod
    def edges_reference_valid_nodes(
        cls, edges: list[WorkflowEdge], info
    ) -> list[WorkflowEdge]:
        """Reject edges whose endpoints are not declared node ids.

        The check is skipped when ``nodes`` is absent from ``info.data``.
        """
        if "nodes" in info.data:
            node_ids = {n.id for n in info.data["nodes"]}
            for edge in edges:
                if edge.from_node not in node_ids:
                    raise ValueError(
                        f"edge references unknown node id: {edge.from_node!r}"
                    )
                if edge.to_node not in node_ids:
                    raise ValueError(
                        f"edge references unknown node id: {edge.to_node!r}"
                    )
        return edges

    @model_validator(mode="after")
    def node_ids_are_unique(self) -> "WorkflowConfig":
        """Reject the config when two nodes share the same id."""
        seen: set[str] = set()
        for node in self.nodes:
            if node.id in seen:
                raise ValueError(f"duplicate node id: {node.id!r}")
            seen.add(node.id)
        return self

    @model_validator(mode="after")
    def node_params_match_registry(self) -> "WorkflowConfig":
        """Check every node's params against its registry definition.

        A param key is allowed when it is a registry field of the node's step,
        a workflow meta param, or a dynamic template-input param. Dynamic
        template-input values are not inspected; all other values are
        validated by the matching helper.
        """
        for node in self.nodes:
            definition = _require_node_definition(node)
            field_definitions = {field.key: field for field in definition.fields}
            allowed_keys = field_definitions.keys() | _WORKFLOW_META_PARAM_KEYS
            unknown_keys = sorted(
                key
                for key in node.params
                if key not in allowed_keys and not _is_dynamic_template_input_param(node, key)
            )
            if unknown_keys:
                joined = ", ".join(repr(key) for key in unknown_keys)
                raise ValueError(
                    f"node {_coerce_node_label(node)} uses unknown param key(s): {joined}"
                )
            for key, value in node.params.items():
                if _is_dynamic_template_input_param(node, key):
                    continue
                if key in _WORKFLOW_META_PARAM_KEYS:
                    _validate_meta_param_value(node=node, key=key, value=value)
                    continue
                field_definition = field_definitions.get(key)
                if field_definition is None:
                    continue
                _validate_param_value(
                    node=node,
                    field_definition=field_definition,
                    value=value,
                )
        return self

    @model_validator(mode="after")
    def ui_family_matches_node_families(self) -> "WorkflowConfig":
        """Check the declared UI family and execution mode against the node families.

        Graph/shadow execution is rejected when the nodes mix the ``cad_file``
        and ``order_line`` families. When ``ui.family`` is set and a concrete
        family can be inferred from the nodes, the two must agree.
        """
        definitions = [_require_node_definition(node) for node in self.nodes]
        families = {definition.family for definition in definitions}
        inferred_family = _infer_concrete_workflow_family(definitions)
        if not families:
            return self
        # A missing ``ui`` block is treated as legacy execution.
        execution_mode = self.ui.execution_mode if self.ui is not None else "legacy"
        if execution_mode in {"graph", "shadow"} and inferred_family == "mixed":
            raise ValueError(
                "workflow ui.execution_mode must stay single-family for graph/shadow execution"
            )
        if inferred_family is None:
            return self
        if self.ui is None or self.ui.family is None:
            return self
        if self.ui.family != inferred_family:
            ordered_families = ", ".join(sorted(families))
            raise ValueError(
                f"workflow ui.family={self.ui.family!r} does not match node families: {ordered_families}"
            )
        return self

    @model_validator(mode="after")
    def node_contracts_are_connected(self) -> "WorkflowConfig":
        """Check that each node's input contract is satisfied by upstream artifacts.

        Runs only for graph/shadow execution. Nodes are visited in topological
        order. A node's inputs are whatever its upstream nodes passed along,
        plus the seed artifacts of its own family. Everything a node has or
        provides is passed on to its direct successors.

        Nodes that sit on or behind a cycle are never dequeued and therefore
        not checked here; ``edges_are_unique_and_acyclic`` is what rejects
        cyclic graphs.
        """
        execution_mode = self.ui.execution_mode if self.ui is not None else "legacy"
        if execution_mode not in {"graph", "shadow"}:
            return self

        node_by_id = {node.id: node for node in self.nodes}
        adjacency: dict[str, list[str]] = {node.id: [] for node in self.nodes}
        in_degree: dict[str, int] = {node.id: 0 for node in self.nodes}
        available_artifacts: dict[str, set[str]] = {node.id: set() for node in self.nodes}
        for edge in self.edges:
            adjacency[edge.from_node].append(edge.to_node)
            in_degree[edge.to_node] += 1

        queue: deque[str] = deque(
            node_id for node_id, degree in in_degree.items() if degree == 0
        )
        while queue:
            node_id = queue.popleft()
            node = node_by_id[node_id]
            definition = _require_node_definition(node)
            node_inputs = available_artifacts[node_id] | _context_seed_artifacts(definition)

            required = set(definition.input_contract.get("requires", []))
            missing_required = sorted(required - node_inputs)
            if missing_required:
                joined = ", ".join(repr(value) for value in missing_required)
                raise ValueError(
                    f"node {_coerce_node_label(node)} is missing required input artifact(s): {joined}"
                )

            required_any = set(definition.input_contract.get("requires_any", []))
            if required_any and not node_inputs.intersection(required_any):
                joined = ", ".join(repr(value) for value in sorted(required_any))
                raise ValueError(
                    f"node {_coerce_node_label(node)} requires at least one upstream artifact from: {joined}"
                )

            node_outputs = node_inputs | set(definition.output_contract.get("provides", []))
            for downstream_id in adjacency[node_id]:
                available_artifacts[downstream_id].update(node_outputs)
                in_degree[downstream_id] -= 1
                if in_degree[downstream_id] == 0:
                    queue.append(downstream_id)
        return self

    @model_validator(mode="after")
    def edges_are_unique_and_acyclic(self) -> "WorkflowConfig":
        """Reject self-loops, duplicate edges and cycles.

        Cycle detection is a topological sort: if not every node can be
        dequeued with in-degree zero, the graph contains a cycle.
        """
        edge_pairs: set[tuple[str, str]] = set()
        adjacency: dict[str, list[str]] = {node.id: [] for node in self.nodes}
        in_degree: dict[str, int] = {node.id: 0 for node in self.nodes}
        for edge in self.edges:
            edge_pair = (edge.from_node, edge.to_node)
            if edge.from_node == edge.to_node:
                raise ValueError(f"self-referential edge is not allowed: {edge.from_node!r}")
            if edge_pair in edge_pairs:
                raise ValueError(
                    f"duplicate edge is not allowed: {edge.from_node!r} -> {edge.to_node!r}"
                )
            edge_pairs.add(edge_pair)
            adjacency[edge.from_node].append(edge.to_node)
            in_degree[edge.to_node] += 1

        # deque gives O(1) pops from the left; list.pop(0) is O(n) per pop.
        queue: deque[str] = deque(
            node_id for node_id, degree in in_degree.items() if degree == 0
        )
        processed = 0
        while queue:
            node_id = queue.popleft()
            processed += 1
            for neighbor in adjacency[node_id]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)
        if processed != len(self.nodes):
            raise ValueError(
                "workflow graph must be acyclic"
            )
        return self