Files
HartOMat/backend/app/domains/rendering/workflow_schema.py
T

93 lines
2.6 KiB
Python

"""Pydantic schema for validated WorkflowDefinition.config JSONB.
A workflow config is a versioned DAG description stored as JSONB. Before
being dispatched (or saved), the raw dict must pass this schema.
Example config::
{
"version": 1,
"nodes": [
{"id": "n1", "step": "resolve_step_path", "params": {}},
{"id": "n2", "step": "blender_still", "params": {"engine": "cycles"}}
],
"edges": [
{"from": "n1", "to": "n2"}
]
}
"""
from graphlib import CycleError, TopologicalSorter

from pydantic import BaseModel, Field, field_validator, model_validator

from app.core.process_steps import StepName
class WorkflowPosition(BaseModel):
    """An (x, y) coordinate pair, used as a node's UI position.

    Both coordinates are required.
    """

    # Horizontal coordinate.
    x: float
    # Vertical coordinate.
    y: float
class WorkflowNodeUI(BaseModel):
    """Optional presentation metadata attached to a single workflow node.

    Every field may be omitted (defaults to ``None``). The graph validators
    on the workflow config do not inspect any of these values.
    """

    # Free-form UI node type; the allowed values are not checked here.
    type: str | None = None
    # Where the node is drawn, if the client stored a position.
    position: WorkflowPosition | None = None
    # Human-readable caption for the node.
    label: str | None = None
class WorkflowNode(BaseModel):
    """A single step in the workflow graph.

    ``id`` is the handle that edges refer to (uniqueness is enforced by the
    enclosing WorkflowConfig). ``step`` must be a valid ``StepName`` value.
    ``params`` is an unvalidated dict passed along with the step.
    """

    id: str
    # Validated against the StepName StrEnum; unknown step names are rejected.
    step: StepName
    # A mutable default is safe here: pydantic copies it for every instance.
    params: dict = Field(default={})
    ui: WorkflowNodeUI | None = Field(default=None)
class WorkflowEdge(BaseModel):
    """A directed edge from one node id to another.

    In the stored JSON the endpoints are keyed ``"from"`` and ``"to"``.
    ``from`` is a Python keyword, so the attributes are named ``from_node``
    and ``to_node`` and the JSON keys are attached as aliases.
    ``populate_by_name`` lets callers construct an edge with either the
    alias or the attribute name.
    """

    model_config = {"populate_by_name": True}

    from_node: str = Field(..., alias="from")
    to_node: str = Field(..., alias="to")
class WorkflowUI(BaseModel):
    """Workflow-level UI metadata: currently only an optional preset."""

    # Optional preset identifier (free-form string, not validated here).
    preset: str | None = Field(default=None)
class WorkflowConfig(BaseModel):
    """Validated top-level workflow config (the WorkflowDefinition.config JSONB).

    A successfully validated config guarantees that:

    * there is at least one node,
    * every edge references a declared node id,
    * node ids are unique, and
    * the edges form a directed acyclic graph (no self-loops, no cycles).
    """

    version: int = 1
    nodes: list[WorkflowNode]
    edges: list[WorkflowEdge] = []
    ui: WorkflowUI | None = None

    @field_validator("nodes")
    @classmethod
    def nodes_not_empty(cls, v: list[WorkflowNode]) -> list[WorkflowNode]:
        """Reject a workflow that declares no nodes.

        Raises:
            ValueError: if ``v`` is empty.
        """
        if not v:
            raise ValueError("workflow must have at least one node")
        return v

    @field_validator("edges")
    @classmethod
    def edges_reference_valid_nodes(
        cls, edges: list[WorkflowEdge], info
    ) -> list[WorkflowEdge]:
        """Reject edges whose endpoints are not declared node ids.

        ``info`` is pydantic's ValidationInfo. ``info.data`` holds only the
        fields that have already validated successfully. ``nodes`` is
        declared before ``edges``, so it is present unless its own
        validation failed. In that case this check is skipped and the nodes
        error is reported instead.

        Raises:
            ValueError: on the first endpoint (``from`` before ``to``) that
                names an unknown node id.
        """
        if "nodes" not in info.data:
            return edges
        node_ids = {n.id for n in info.data["nodes"]}
        for edge in edges:
            for endpoint in (edge.from_node, edge.to_node):
                if endpoint not in node_ids:
                    raise ValueError(
                        f"edge references unknown node id: {endpoint!r}"
                    )
        return edges

    @model_validator(mode="after")
    def node_ids_are_unique(self) -> "WorkflowConfig":
        """Reject configs in which two nodes share the same id.

        Raises:
            ValueError: naming the first repeated id.
        """
        seen: set[str] = set()
        for node in self.nodes:
            if node.id in seen:
                raise ValueError(f"duplicate node id: {node.id!r}")
            seen.add(node.id)
        return self

    @model_validator(mode="after")
    def edges_form_a_dag(self) -> "WorkflowConfig":
        """Reject edge sets that contain a cycle, including self-loops.

        The module contract says a config is a DAG description. Without this
        check, configs such as ``n1 -> n2 -> n1`` or ``n1 -> n1`` would pass
        validation.

        Raises:
            ValueError: listing one detected cycle in edge order, with the
                first node repeated at the end.
        """
        # Map every node to its predecessor set. Edges are added as
        # ``to_node`` depending on ``from_node``. After-validators only run
        # once field validation passed, so all endpoints are declared ids.
        sorter = TopologicalSorter({node.id: set() for node in self.nodes})
        for edge in self.edges:
            sorter.add(edge.to_node, edge.from_node)
        try:
            sorter.prepare()
        except CycleError as err:
            # args[1] is the cycle: each node is an immediate predecessor of
            # the next, and the first and last entries are the same node.
            cycle = " -> ".join(repr(node_id) for node_id in err.args[1])
            raise ValueError(f"workflow edges contain a cycle: {cycle}") from None
        return self