435 lines
17 KiB
Python
435 lines
17 KiB
Python
"""Pydantic schema for validated WorkflowDefinition.config JSONB.
|
|
|
|
A workflow config is a versioned DAG description stored as JSONB. Before
|
|
being dispatched (or saved), the raw dict must pass this schema.
|
|
|
|
Example config::
|
|
|
|
{
|
|
"version": 1,
|
|
"nodes": [
|
|
{"id": "n1", "step": "resolve_step_path", "params": {}},
|
|
{"id": "n2", "step": "blender_still", "params": {"engine": "cycles"}}
|
|
],
|
|
"edges": [
|
|
{"from": "n1", "to": "n2"}
|
|
]
|
|
}
|
|
"""
|
|
import math
from collections import deque
from typing import Any, Literal
from uuid import UUID

from pydantic import BaseModel, Field, field_validator, model_validator

from app.core.process_steps import StepName
from app.domains.rendering.workflow_node_registry import (
    WorkflowNodeDefinition,
    WorkflowNodeFieldDefinition,
    get_node_definition,
)
|
|
|
|
|
|
# Param keys any node may carry on top of its registry-declared fields;
# their values are checked by _validate_meta_param_value.
_WORKFLOW_META_PARAM_KEYS = {"retry_policy", "failure_policy"}

# resolve_template nodes additionally accept free-form params whose key
# starts with this prefix followed by a non-blank input name.
_TEMPLATE_INPUT_PARAM_PREFIX = "template_input__"

# Accepted lengths for the "hex_color" text format: "#RRGGBB" and "#RRGGBBAA".
_HEX_COLOR_LENGTHS = {len("#RRGGBB"), len("#RRGGBBAA")}

# Characters permitted by the "safe_filename_suffix" text format:
# ASCII letters, digits, '.', '_' and '-'.
_SAFE_FILENAME_SUFFIX_CHARS = set(
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "0123456789"
    "._-"
)
|
|
|
|
|
|
def _context_seed_artifacts(definition: WorkflowNodeDefinition) -> set[str]:
    """Return the artifacts implicitly available to a node of this family.

    ``order_line`` nodes start with ``order_line_record`` and ``cad_file``
    nodes with ``cad_file_record``; any other family starts with nothing.
    A fresh set is returned on every call.
    """
    seed_by_family = {
        "order_line": "order_line_record",
        "cad_file": "cad_file_record",
    }
    seed = seed_by_family.get(definition.family)
    return set() if seed is None else {seed}
|
|
|
|
|
|
def _infer_concrete_workflow_family(
    definitions: list[WorkflowNodeDefinition],
) -> Literal["cad_file", "order_line", "mixed"] | None:
    """Infer the workflow's concrete family from its node definitions.

    Only the ``cad_file`` and ``order_line`` families count as concrete.
    Returns that family when exactly one is present, ``"mixed"`` when both
    are, and ``None`` when no node belongs to a concrete family.
    """
    present = {definition.family for definition in definitions} & {"cad_file", "order_line"}
    if len(present) == 1:
        (only_family,) = present
        return only_family
    return "mixed" if present else None
|
|
|
|
|
|
def _coerce_node_label(node: "WorkflowNode") -> str:
    """Format a node as ``'<id>' (<step value>)`` for validation error messages."""
    node_id_repr = repr(node.id)
    return "{} ({})".format(node_id_repr, node.step.value)
|
|
|
|
|
|
def _require_node_definition(node: "WorkflowNode") -> WorkflowNodeDefinition:
    """Return the registry definition for ``node.step``.

    Raises:
        ValueError: if ``get_node_definition`` has no entry for the step.
    """
    if (definition := get_node_definition(node.step)) is not None:
        return definition
    raise ValueError(
        f"node {_coerce_node_label(node)} is not registered in workflow_node_registry"
    )
|
|
|
|
|
|
def _is_dynamic_template_input_param(node: "WorkflowNode", key: str) -> bool:
    """Tell whether ``key`` is a free-form template input on a resolve_template node.

    Such keys look like ``template_input__<name>`` where ``<name>`` is not
    blank; they bypass the registry's declared-field check.
    """
    if node.step != StepName.RESOLVE_TEMPLATE:
        return False
    if not isinstance(key, str) or not key.startswith(_TEMPLATE_INPUT_PARAM_PREFIX):
        return False
    input_name = key[len(_TEMPLATE_INPUT_PARAM_PREFIX):]
    return bool(input_name.strip())
|
|
|
|
|
|
def _validate_param_value(
    *,
    node: "WorkflowNode",
    field_definition: WorkflowNodeFieldDefinition,
    value: Any,
) -> None:
    """Validate one registry-declared param value against its field definition.

    ``None`` is treated as "not set" and always accepted. Otherwise the value
    must match ``field_definition.type``:

    * ``number``  -- an int or float (bools and NaN rejected) within the
      optional ``min``/``max`` bounds;
    * ``boolean`` -- a real ``bool``;
    * ``select``  -- one of the values in ``options``;
    * ``text``    -- a string that is non-blank unless ``allow_blank``, at
      most ``max_length`` characters, and matching ``text_format``.

    Field types other than these are accepted without further checks.

    Raises:
        ValueError: naming the node and param when the value is invalid.
            Only ValueError is raised so the caller's pydantic validator
            reports it as a validation error rather than a crash.
    """
    if value is None:
        return

    field_label = f"node {_coerce_node_label(node)} param {field_definition.key!r}"

    if field_definition.type == "number":
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValueError(f"{field_label} must be a number")
        # NaN compares False against every bound, so it would otherwise slip
        # past the min/max checks below.
        if isinstance(value, float) and math.isnan(value):
            raise ValueError(f"{field_label} must be a number")
        # Compare the raw value: Python compares int and float exactly, while
        # float(value) raises OverflowError for very large JSON integers.
        if field_definition.min is not None and value < field_definition.min:
            raise ValueError(f"{field_label} must be >= {field_definition.min:g}")
        if field_definition.max is not None and value > field_definition.max:
            raise ValueError(f"{field_label} must be <= {field_definition.max:g}")
        return

    if field_definition.type == "boolean":
        if not isinstance(value, bool):
            raise ValueError(f"{field_label} must be a boolean")
        return

    if field_definition.type == "select":
        valid_values = {option.value for option in field_definition.options}
        try:
            is_valid = value in valid_values
        except TypeError:
            # Unhashable JSON values (lists/objects) can never equal an option
            # value (options are stored in a set, so they are hashable);
            # report them as invalid instead of leaking a TypeError.
            is_valid = False
        if not is_valid:
            allowed_values = ", ".join(repr(option) for option in sorted(valid_values, key=repr))
            raise ValueError(f"{field_label} must be one of: {allowed_values}")
        return

    if field_definition.type == "text":
        if not isinstance(value, str):
            raise ValueError(f"{field_label} must be a string")

        stripped_value = value.strip()
        if stripped_value == "":
            if field_definition.allow_blank:
                return
            raise ValueError(f"{field_label} may not be blank")

        # The length limit applies to the unstripped value as submitted.
        if field_definition.max_length is not None and len(value) > field_definition.max_length:
            raise ValueError(
                f"{field_label} must be at most {field_definition.max_length} characters"
            )

        if field_definition.text_format == "plain":
            return

        if field_definition.text_format == "uuid":
            try:
                UUID(stripped_value)
            except ValueError as exc:
                raise ValueError(f"{field_label} must be a valid UUID") from exc
            return

        if field_definition.text_format == "absolute_path":
            if not stripped_value.startswith("/"):
                raise ValueError(f"{field_label} must be an absolute path")
            return

        if field_definition.text_format == "absolute_blend_path":
            if not stripped_value.startswith("/"):
                raise ValueError(f"{field_label} must be an absolute path")
            if not stripped_value.lower().endswith(".blend"):
                raise ValueError(f"{field_label} must point to a .blend file")
            return

        if field_definition.text_format == "absolute_glb_path":
            if not stripped_value.startswith("/"):
                raise ValueError(f"{field_label} must be an absolute path")
            if not stripped_value.lower().endswith(".glb"):
                raise ValueError(f"{field_label} must point to a .glb file")
            return

        if field_definition.text_format == "float_string":
            try:
                float(stripped_value)
            except ValueError as exc:
                raise ValueError(f"{field_label} must be a valid numeric string") from exc
            return

        if field_definition.text_format == "hex_color":
            if len(stripped_value) not in _HEX_COLOR_LENGTHS or not stripped_value.startswith("#"):
                raise ValueError(f"{field_label} must be a hex color like #FFFFFF or #FFFFFFFF")
            color_digits = stripped_value[1:]
            if any(character not in "0123456789abcdefABCDEF" for character in color_digits):
                raise ValueError(f"{field_label} must be a hex color like #FFFFFF or #FFFFFFFF")
            return

        if field_definition.text_format == "safe_filename_suffix":
            if any(character not in _SAFE_FILENAME_SUFFIX_CHARS for character in stripped_value):
                raise ValueError(
                    f"{field_label} may only contain letters, numbers, '.', '-' or '_'"
                )
            return

        raise ValueError(
            f"{field_label} uses unsupported text format {field_definition.text_format!r}"
        )
|
|
|
|
|
|
def _validate_meta_param_value(*, node: "WorkflowNode", key: str, value: Any) -> None:
    """Validate a workflow meta param (``retry_policy`` or ``failure_policy``).

    * ``retry_policy`` must be an object whose only key is ``max_attempts``,
      an int between 1 and 5; an omitted ``max_attempts`` is validated as 1.
    * ``failure_policy`` must be an object whose only keys are
      ``halt_workflow`` and ``fallback_to_legacy``, each a bool when present.

    Raises:
        ValueError: for a malformed value, or for any other ``key``.
    """
    field_label = f"node {_coerce_node_label(node)} meta param {key!r}"

    if key == "retry_policy":
        if not isinstance(value, dict):
            raise ValueError(f"{field_label} must be an object")
        unknown_keys = sorted(raw_key for raw_key in value if raw_key not in {"max_attempts"})
        if unknown_keys:
            joined = ", ".join(repr(raw_key) for raw_key in unknown_keys)
            raise ValueError(f"{field_label} uses unknown key(s): {joined}")
        max_attempts = value.get("max_attempts", 1)
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(max_attempts, bool) or not isinstance(max_attempts, int):
            raise ValueError(f"{field_label} field 'max_attempts' must be an integer")
        if max_attempts < 1 or max_attempts > 5:
            raise ValueError(f"{field_label} field 'max_attempts' must be between 1 and 5")
        return

    if key == "failure_policy":
        if not isinstance(value, dict):
            raise ValueError(f"{field_label} must be an object")
        # A tuple (not a set) fixes the check order, so when several flags are
        # invalid the reported one is the same in every process; set iteration
        # order for str depends on per-process hash randomization.
        allowed_keys = ("halt_workflow", "fallback_to_legacy")
        unknown_keys = sorted(raw_key for raw_key in value if raw_key not in allowed_keys)
        if unknown_keys:
            joined = ", ".join(repr(raw_key) for raw_key in unknown_keys)
            raise ValueError(f"{field_label} uses unknown key(s): {joined}")
        for bool_key in allowed_keys:
            if bool_key in value and not isinstance(value[bool_key], bool):
                raise ValueError(f"{field_label} field {bool_key!r} must be a boolean")
        return

    raise ValueError(f"{field_label} is not supported")
|
|
|
|
|
|
class WorkflowPosition(BaseModel):
    """An ``x``/``y`` coordinate pair used for ``WorkflowNodeUI.position``.

    NOTE(review): presumably editor canvas coordinates -- confirm units/origin.
    """

    x: float
    y: float
|
|
|
|
|
|
class WorkflowNodeUI(BaseModel):
    """Optional presentation metadata attached to a workflow node.

    All fields are optional and none of them are inspected by the validators
    in this module.
    """

    type: str | None = None
    position: WorkflowPosition | None = None
    label: str | None = None
|
|
|
|
|
|
class WorkflowNode(BaseModel):
    """A single step in the workflow DAG.

    ``id`` is referenced by edges and must be unique within a config;
    ``params`` are checked against the step's registry definition by
    ``WorkflowConfig``.
    """

    id: str
    step: StepName  # validated against the StepName StrEnum
    # Registry-declared field values plus optional meta params
    # (retry_policy / failure_policy); defaults to an empty dict.
    params: dict[str, Any] = Field(default_factory=dict)
    ui: WorkflowNodeUI | None = None
|
|
|
|
|
|
class WorkflowEdge(BaseModel):
    """A directed edge between two node ids.

    Input uses the JSON keys ``from``/``to``; ``populate_by_name`` also
    allows constructing it with the Python field names
    ``from_node``/``to_node``.
    """

    # "from" is a Python keyword, so we alias it.
    from_node: str = Field(alias="from")
    to_node: str = Field(alias="to")

    model_config = {"populate_by_name": True}
|
|
|
|
|
|
class WorkflowUI(BaseModel):
    """Workflow-level UI settings that also steer validation.

    ``execution_mode`` of ``"graph"`` or ``"shadow"`` enables the
    single-family and input-contract checks in ``WorkflowConfig``;
    ``family``, when set, must match the family inferred from the nodes.
    """

    preset: str | None = None
    execution_mode: Literal["legacy", "graph", "shadow"] | None = None
    family: Literal["cad_file", "order_line", "mixed"] | None = None
|
|
|
|
|
|
class WorkflowConfig(BaseModel):
    """Validated workflow DAG configuration.

    Field validators require at least one node and edges whose endpoints name
    existing nodes. Model-level validators then enforce:

    * unique node ids;
    * node params that match the workflow node registry;
    * a ``ui.family`` consistent with the concrete node families, and a
      single concrete family for graph/shadow execution;
    * satisfied node input contracts for graph/shadow execution;
    * no self-referential edges, no duplicate edges and no cycles.
    """

    version: int = 1
    nodes: list[WorkflowNode]
    edges: list[WorkflowEdge] = Field(default_factory=list)
    ui: WorkflowUI | None = None

    @field_validator("nodes")
    @classmethod
    def nodes_not_empty(cls, v: list[WorkflowNode]) -> list[WorkflowNode]:
        """Reject a workflow without any nodes."""
        if not v:
            raise ValueError("workflow must have at least one node")
        return v

    @field_validator("edges")
    @classmethod
    def edges_reference_valid_nodes(
        cls, edges: list[WorkflowEdge], info
    ) -> list[WorkflowEdge]:
        """Reject edges whose ``from``/``to`` does not name a declared node.

        Skipped when ``nodes`` is absent from ``info.data`` (pydantic omits
        fields that failed validation, whose error is reported instead).
        """
        if "nodes" not in info.data:
            return edges
        node_ids = {node.id for node in info.data["nodes"]}
        for edge in edges:
            for endpoint in (edge.from_node, edge.to_node):
                if endpoint not in node_ids:
                    raise ValueError(f"edge references unknown node id: {endpoint!r}")
        return edges

    @model_validator(mode="after")
    def node_ids_are_unique(self) -> "WorkflowConfig":
        """Reject configs in which two nodes share an id."""
        seen: set[str] = set()
        for node in self.nodes:
            if node.id in seen:
                raise ValueError(f"duplicate node id: {node.id!r}")
            seen.add(node.id)
        return self

    @model_validator(mode="after")
    def node_params_match_registry(self) -> "WorkflowConfig":
        """Check every node's params against its registry definition.

        Allowed keys are the definition's declared fields, the workflow meta
        params, and (for resolve_template) dynamic ``template_input__*`` keys.
        Declared fields and meta params are then value-checked.
        """
        for node in self.nodes:
            definition = _require_node_definition(node)
            field_definitions = {field.key: field for field in definition.fields}
            allowed_keys = field_definitions.keys() | _WORKFLOW_META_PARAM_KEYS

            unknown_keys = sorted(
                key
                for key in node.params
                if key not in allowed_keys and not _is_dynamic_template_input_param(node, key)
            )
            if unknown_keys:
                joined = ", ".join(repr(key) for key in unknown_keys)
                raise ValueError(
                    f"node {_coerce_node_label(node)} uses unknown param key(s): {joined}"
                )

            for key, value in node.params.items():
                if _is_dynamic_template_input_param(node, key):
                    continue
                # Meta params take precedence over a same-named declared field.
                if key in _WORKFLOW_META_PARAM_KEYS:
                    _validate_meta_param_value(node=node, key=key, value=value)
                    continue
                field_definition = field_definitions.get(key)
                if field_definition is None:
                    continue
                _validate_param_value(
                    node=node,
                    field_definition=field_definition,
                    value=value,
                )
        return self

    @model_validator(mode="after")
    def ui_family_matches_node_families(self) -> "WorkflowConfig":
        """Keep ``ui.family`` and graph/shadow execution consistent with node families.

        Graph/shadow workflows may not mix concrete families. When the nodes
        have a concrete family and ``ui.family`` is set, the two must agree.
        """
        definitions = [_require_node_definition(node) for node in self.nodes]
        inferred_family = _infer_concrete_workflow_family(definitions)

        execution_mode = self.ui.execution_mode if self.ui is not None else "legacy"
        if execution_mode in {"graph", "shadow"} and inferred_family == "mixed":
            raise ValueError(
                "workflow ui.execution_mode must stay single-family for graph/shadow execution"
            )

        if inferred_family is None or self.ui is None or self.ui.family is None:
            return self
        if self.ui.family != inferred_family:
            ordered_families = ", ".join(sorted({definition.family for definition in definitions}))
            raise ValueError(
                f"workflow ui.family={self.ui.family!r} does not match node families: {ordered_families}"
            )
        return self

    @model_validator(mode="after")
    def node_contracts_are_connected(self) -> "WorkflowConfig":
        """Check node input contracts along the DAG (graph/shadow execution only).

        Nodes are visited in topological order (Kahn's algorithm). A node's
        inputs are everything propagated from its upstream nodes plus its own
        family seed artifacts; it must have every ``requires`` artifact and at
        least one ``requires_any`` artifact. Its inputs plus its ``provides``
        artifacts are propagated to each downstream node.

        Nodes on or behind a cycle are never dequeued and are not checked
        here; ``edges_are_unique_and_acyclic`` rejects cyclic graphs.
        """
        execution_mode = self.ui.execution_mode if self.ui is not None else "legacy"
        if execution_mode not in {"graph", "shadow"}:
            return self

        node_by_id = {node.id: node for node in self.nodes}
        adjacency: dict[str, list[str]] = {node.id: [] for node in self.nodes}
        in_degree: dict[str, int] = {node.id: 0 for node in self.nodes}
        available_artifacts: dict[str, set[str]] = {node.id: set() for node in self.nodes}

        for edge in self.edges:
            adjacency[edge.from_node].append(edge.to_node)
            in_degree[edge.to_node] += 1

        queue: deque[str] = deque(
            node_id for node_id, degree in in_degree.items() if degree == 0
        )

        while queue:
            node_id = queue.popleft()
            node = node_by_id[node_id]
            definition = _require_node_definition(node)

            node_inputs = available_artifacts[node_id] | _context_seed_artifacts(definition)

            required = set(definition.input_contract.get("requires", []))
            missing_required = sorted(required - node_inputs)
            if missing_required:
                joined = ", ".join(repr(value) for value in missing_required)
                raise ValueError(
                    f"node {_coerce_node_label(node)} is missing required input artifact(s): {joined}"
                )

            required_any = set(definition.input_contract.get("requires_any", []))
            if required_any and not node_inputs.intersection(required_any):
                joined = ", ".join(repr(value) for value in sorted(required_any))
                raise ValueError(
                    f"node {_coerce_node_label(node)} requires at least one upstream artifact from: {joined}"
                )

            node_outputs = node_inputs | set(definition.output_contract.get("provides", []))
            for downstream_id in adjacency[node_id]:
                available_artifacts[downstream_id].update(node_outputs)
                in_degree[downstream_id] -= 1
                if in_degree[downstream_id] == 0:
                    queue.append(downstream_id)

        return self

    @model_validator(mode="after")
    def edges_are_unique_and_acyclic(self) -> "WorkflowConfig":
        """Reject self-referential edges, duplicate edges and cycles.

        Cycle detection uses Kahn's algorithm: if some node never reaches
        in-degree 0, it lies on or behind a cycle.
        """
        edge_pairs: set[tuple[str, str]] = set()
        adjacency: dict[str, list[str]] = {node.id: [] for node in self.nodes}
        in_degree: dict[str, int] = {node.id: 0 for node in self.nodes}

        for edge in self.edges:
            edge_pair = (edge.from_node, edge.to_node)
            if edge.from_node == edge.to_node:
                raise ValueError(f"self-referential edge is not allowed: {edge.from_node!r}")
            if edge_pair in edge_pairs:
                raise ValueError(
                    f"duplicate edge is not allowed: {edge.from_node!r} -> {edge.to_node!r}"
                )
            edge_pairs.add(edge_pair)
            adjacency[edge.from_node].append(edge.to_node)
            in_degree[edge.to_node] += 1

        # deque gives O(1) pops from the left (list.pop(0) is O(n)).
        queue: deque[str] = deque(
            node_id for node_id, degree in in_degree.items() if degree == 0
        )
        processed = 0

        while queue:
            node_id = queue.popleft()
            processed += 1
            for neighbor in adjacency[node_id]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        if processed != len(self.nodes):
            raise ValueError(
                "workflow graph must be acyclic"
            )

        return self
|