feat: harden workflow graph contracts
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""Pydantic schema for validated WorkflowDefinition.config JSONB.
|
||||
|
||||
A workflow config is a versioned DAG description stored as JSONB. Before
|
||||
A workflow config is a versioned DAG description stored as JSONB. Before
|
||||
being dispatched (or saved), the raw dict must pass this schema.
|
||||
|
||||
Example config::
|
||||
@@ -16,11 +16,62 @@ Example config::
|
||||
]
|
||||
}
|
||||
"""
|
||||
from typing import Literal
|
||||
from collections import deque
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from app.core.process_steps import StepName
|
||||
from app.domains.rendering.workflow_node_registry import (
|
||||
WorkflowNodeDefinition,
|
||||
WorkflowNodeFieldDefinition,
|
||||
get_node_definition,
|
||||
)
|
||||
|
||||
|
||||
def _context_seed_artifacts(definition: WorkflowNodeDefinition) -> set[str]:
|
||||
if definition.family == "order_line":
|
||||
return {"order_line_record"}
|
||||
if definition.family == "cad_file":
|
||||
return {"cad_file_record"}
|
||||
return set()
|
||||
|
||||
|
||||
def _coerce_node_label(node: "WorkflowNode") -> str:
|
||||
return f"{node.id!r} ({node.step.value})"
|
||||
|
||||
|
||||
def _validate_param_value(
|
||||
*,
|
||||
node: "WorkflowNode",
|
||||
field_definition: WorkflowNodeFieldDefinition,
|
||||
value: Any,
|
||||
) -> None:
|
||||
if value is None:
|
||||
return
|
||||
|
||||
field_label = f"node {_coerce_node_label(node)} param {field_definition.key!r}"
|
||||
|
||||
if field_definition.type == "number":
|
||||
if isinstance(value, bool) or not isinstance(value, (int, float)):
|
||||
raise ValueError(f"{field_label} must be a number")
|
||||
numeric_value = float(value)
|
||||
if field_definition.min is not None and numeric_value < field_definition.min:
|
||||
raise ValueError(f"{field_label} must be >= {field_definition.min:g}")
|
||||
if field_definition.max is not None and numeric_value > field_definition.max:
|
||||
raise ValueError(f"{field_label} must be <= {field_definition.max:g}")
|
||||
return
|
||||
|
||||
if field_definition.type == "boolean":
|
||||
if not isinstance(value, bool):
|
||||
raise ValueError(f"{field_label} must be a boolean")
|
||||
return
|
||||
|
||||
if field_definition.type == "select":
|
||||
valid_values = {option.value for option in field_definition.options}
|
||||
if value not in valid_values:
|
||||
allowed_values = ", ".join(repr(option) for option in sorted(valid_values, key=repr))
|
||||
raise ValueError(f"{field_label} must be one of: {allowed_values}")
|
||||
|
||||
|
||||
class WorkflowPosition(BaseModel):
|
||||
@@ -37,7 +88,7 @@ class WorkflowNodeUI(BaseModel):
|
||||
class WorkflowNode(BaseModel):
|
||||
id: str
|
||||
step: StepName # validated against the StepName StrEnum
|
||||
params: dict = {}
|
||||
params: dict[str, Any] = Field(default_factory=dict)
|
||||
ui: WorkflowNodeUI | None = None
|
||||
|
||||
|
||||
@@ -52,12 +103,13 @@ class WorkflowEdge(BaseModel):
|
||||
class WorkflowUI(BaseModel):
|
||||
preset: str | None = None
|
||||
execution_mode: Literal["legacy", "graph", "shadow"] | None = None
|
||||
family: Literal["cad_file", "order_line", "mixed"] | None = None
|
||||
|
||||
|
||||
class WorkflowConfig(BaseModel):
|
||||
version: int = 1
|
||||
nodes: list[WorkflowNode]
|
||||
edges: list[WorkflowEdge] = []
|
||||
edges: list[WorkflowEdge] = Field(default_factory=list)
|
||||
ui: WorkflowUI | None = None
|
||||
|
||||
@field_validator("nodes")
|
||||
@@ -93,3 +145,145 @@ class WorkflowConfig(BaseModel):
|
||||
raise ValueError(f"duplicate node id: {node.id!r}")
|
||||
seen.add(node.id)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def node_params_match_registry(self) -> "WorkflowConfig":
|
||||
for node in self.nodes:
|
||||
definition = get_node_definition(node.step)
|
||||
if definition is None:
|
||||
continue
|
||||
field_definitions = {field.key: field for field in definition.fields}
|
||||
allowed_keys = {field.key for field in definition.fields}
|
||||
unknown_keys = sorted(key for key in node.params if key not in allowed_keys)
|
||||
if unknown_keys:
|
||||
joined = ", ".join(repr(key) for key in unknown_keys)
|
||||
raise ValueError(
|
||||
f"node {node.id!r} ({node.step.value}) uses unknown param key(s): {joined}"
|
||||
)
|
||||
for key, value in node.params.items():
|
||||
field_definition = field_definitions.get(key)
|
||||
if field_definition is None:
|
||||
continue
|
||||
_validate_param_value(
|
||||
node=node,
|
||||
field_definition=field_definition,
|
||||
value=value,
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ui_family_matches_node_families(self) -> "WorkflowConfig":
|
||||
families = {
|
||||
definition.family
|
||||
for node in self.nodes
|
||||
if (definition := get_node_definition(node.step)) is not None
|
||||
}
|
||||
if not families:
|
||||
return self
|
||||
|
||||
inferred_family = "mixed" if len(families) > 1 else next(iter(families))
|
||||
execution_mode = self.ui.execution_mode if self.ui is not None else "legacy"
|
||||
if execution_mode in {"graph", "shadow"} and inferred_family == "mixed":
|
||||
raise ValueError(
|
||||
"workflow ui.execution_mode must stay single-family for graph/shadow execution"
|
||||
)
|
||||
if self.ui is None or self.ui.family is None:
|
||||
return self
|
||||
if self.ui.family != inferred_family:
|
||||
ordered_families = ", ".join(sorted(families))
|
||||
raise ValueError(
|
||||
f"workflow ui.family={self.ui.family!r} does not match node families: {ordered_families}"
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def node_contracts_are_connected(self) -> "WorkflowConfig":
|
||||
execution_mode = self.ui.execution_mode if self.ui is not None else "legacy"
|
||||
if execution_mode not in {"graph", "shadow"}:
|
||||
return self
|
||||
|
||||
node_by_id = {node.id: node for node in self.nodes}
|
||||
adjacency: dict[str, list[str]] = {node.id: [] for node in self.nodes}
|
||||
in_degree: dict[str, int] = {node.id: 0 for node in self.nodes}
|
||||
available_artifacts: dict[str, set[str]] = {node.id: set() for node in self.nodes}
|
||||
|
||||
for edge in self.edges:
|
||||
adjacency[edge.from_node].append(edge.to_node)
|
||||
in_degree[edge.to_node] += 1
|
||||
|
||||
queue: deque[str] = deque(
|
||||
node_id for node_id, degree in in_degree.items() if degree == 0
|
||||
)
|
||||
processed = 0
|
||||
|
||||
while queue:
|
||||
node_id = queue.popleft()
|
||||
processed += 1
|
||||
node = node_by_id[node_id]
|
||||
definition = get_node_definition(node.step)
|
||||
if definition is None:
|
||||
continue
|
||||
|
||||
node_inputs = available_artifacts[node_id] | _context_seed_artifacts(definition)
|
||||
required = set(definition.input_contract.get("requires", []))
|
||||
missing_required = sorted(required - node_inputs)
|
||||
if missing_required:
|
||||
joined = ", ".join(repr(value) for value in missing_required)
|
||||
raise ValueError(
|
||||
f"node {_coerce_node_label(node)} is missing required input artifact(s): {joined}"
|
||||
)
|
||||
|
||||
required_any = set(definition.input_contract.get("requires_any", []))
|
||||
if required_any and not node_inputs.intersection(required_any):
|
||||
joined = ", ".join(repr(value) for value in sorted(required_any))
|
||||
raise ValueError(
|
||||
f"node {_coerce_node_label(node)} requires at least one upstream artifact from: {joined}"
|
||||
)
|
||||
|
||||
node_outputs = node_inputs | set(definition.output_contract.get("provides", []))
|
||||
for downstream_id in adjacency[node_id]:
|
||||
available_artifacts[downstream_id].update(node_outputs)
|
||||
in_degree[downstream_id] -= 1
|
||||
if in_degree[downstream_id] == 0:
|
||||
queue.append(downstream_id)
|
||||
|
||||
if processed != len(self.nodes):
|
||||
return self
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def edges_are_unique_and_acyclic(self) -> "WorkflowConfig":
|
||||
edge_pairs: set[tuple[str, str]] = set()
|
||||
adjacency: dict[str, list[str]] = {node.id: [] for node in self.nodes}
|
||||
in_degree: dict[str, int] = {node.id: 0 for node in self.nodes}
|
||||
|
||||
for edge in self.edges:
|
||||
edge_pair = (edge.from_node, edge.to_node)
|
||||
if edge.from_node == edge.to_node:
|
||||
raise ValueError(f"self-referential edge is not allowed: {edge.from_node!r}")
|
||||
if edge_pair in edge_pairs:
|
||||
raise ValueError(
|
||||
f"duplicate edge is not allowed: {edge.from_node!r} -> {edge.to_node!r}"
|
||||
)
|
||||
edge_pairs.add(edge_pair)
|
||||
adjacency[edge.from_node].append(edge.to_node)
|
||||
in_degree[edge.to_node] += 1
|
||||
|
||||
queue = [node_id for node_id, degree in in_degree.items() if degree == 0]
|
||||
processed = 0
|
||||
|
||||
while queue:
|
||||
node_id = queue.pop(0)
|
||||
processed += 1
|
||||
for neighbor in adjacency[node_id]:
|
||||
in_degree[neighbor] -= 1
|
||||
if in_degree[neighbor] == 0:
|
||||
queue.append(neighbor)
|
||||
|
||||
if processed != len(self.nodes):
|
||||
raise ValueError(
|
||||
"workflow graph must be acyclic"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
Reference in New Issue
Block a user