96 lines
2.7 KiB
Python
96 lines
2.7 KiB
Python
"""Pydantic schema for validated WorkflowDefinition.config JSONB.
|
|
|
|
A workflow config is a versioned DAG description stored as JSONB. Before
|
|
being dispatched (or saved), the raw dict must pass this schema.
|
|
|
|
Example config::
|
|
|
|
{
|
|
"version": 1,
|
|
"nodes": [
|
|
{"id": "n1", "step": "resolve_step_path", "params": {}},
|
|
{"id": "n2", "step": "blender_still", "params": {"engine": "cycles"}}
|
|
],
|
|
"edges": [
|
|
{"from": "n1", "to": "n2"}
|
|
]
|
|
}
|
|
"""
|
|
from graphlib import CycleError, TopologicalSorter
from typing import Literal

from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    ValidationInfo,
    field_validator,
    model_validator,
)

from app.core.process_steps import StepName
|
|
|
|
|
|
class WorkflowPosition(BaseModel):
    """A 2-D point carried in a node's optional UI metadata.

    Both coordinates are required floats with no range constraint.
    Presumably these are editor-canvas coordinates -- TODO confirm against
    the frontend that renders them.
    """

    x: float
    y: float
|
|
|
|
|
|
class WorkflowNodeUI(BaseModel):
    """Optional per-node UI metadata (presumably for the workflow editor).

    Every field may be omitted and then defaults to ``None``. None of these
    fields are consulted by the config-level validators in this module.
    """

    # Free-form string; no allowed values are enforced here.
    type: str | None = None
    position: WorkflowPosition | None = None
    label: str | None = None
|
|
|
|
|
|
class WorkflowNode(BaseModel):
    """A single step invocation inside a workflow graph.

    ``id`` is the handle that edges refer to; ``WorkflowConfig`` checks that
    ids are unique across the config. ``step`` is parsed into the
    ``StepName`` enum, so an unrecognized step name fails validation.
    ``params`` is accepted as an arbitrary dict; its contents are not
    inspected here.
    """

    id: str
    step: StepName  # validated against the StepName StrEnum
    # Pydantic copies this default for each instance, so the ``{}`` literal
    # is not shared between nodes.
    params: dict = {}
    ui: WorkflowNodeUI | None = None
|
|
|
|
|
|
class WorkflowEdge(BaseModel):
    """A directed edge from one node to another, both referenced by node id.

    Input uses the keys ``"from"`` and ``"to"``. Because ``from`` is a Python
    keyword, the fields are named ``from_node``/``to_node`` and aliased.
    ``populate_by_name`` also lets Python callers construct an edge with the
    field names. Note that ``model_dump()`` emits ``from_node``/``to_node``
    unless ``by_alias=True`` is passed.
    """

    # ConfigDict is Pydantic's typed config form; type checkers can flag
    # misspelled keys, which a plain dict literal would silently accept.
    model_config = ConfigDict(populate_by_name=True)

    from_node: str = Field(alias="from")
    to_node: str = Field(alias="to")
|
|
|
|
|
|
class WorkflowUI(BaseModel):
    """Optional config-level UI settings.

    Both fields may be omitted and then default to ``None``. ``preset`` is an
    unconstrained string. ``execution_mode``, when present, must be one of
    ``"legacy"``, ``"graph"`` or ``"shadow"``; what each mode does is decided
    outside this module.
    """

    preset: str | None = None
    execution_mode: Literal["legacy", "graph", "shadow"] | None = None
|
|
|
|
|
|
class WorkflowConfig(BaseModel):
    """Validated shape of ``WorkflowDefinition.config``.

    Beyond per-field parsing, the validators below enforce that:

    * there is at least one node,
    * every edge endpoint names an existing node id,
    * node ids are unique, and
    * the edges form a DAG (no cycles, no self-loops).

    Each violation is raised as ``ValueError`` inside a validator, which
    Pydantic surfaces to callers as a ``ValidationError``.
    """

    version: int = 1  # not range-checked here
    nodes: list[WorkflowNode]
    # Pydantic copies mutable defaults per instance, so ``[]`` is not shared.
    edges: list[WorkflowEdge] = []
    ui: WorkflowUI | None = None

    @field_validator("nodes")
    @classmethod
    def nodes_not_empty(cls, v: list[WorkflowNode]) -> list[WorkflowNode]:
        """Reject an empty ``nodes`` list."""
        if not v:
            raise ValueError("workflow must have at least one node")
        return v

    @field_validator("edges")
    @classmethod
    def edges_reference_valid_nodes(
        cls, edges: list[WorkflowEdge], info: ValidationInfo
    ) -> list[WorkflowEdge]:
        """Reject edges whose ``from``/``to`` id is not a declared node id.

        This relies on ``nodes`` being declared before ``edges``. Fields are
        validated in declaration order, so ``info.data`` already holds the
        parsed nodes. If ``nodes`` itself failed validation, it is absent from
        ``info.data``; that error is already reported, so the check is skipped.
        """
        if "nodes" not in info.data:
            return edges
        node_ids = {node.id for node in info.data["nodes"]}
        for edge in edges:
            # ``from`` is checked before ``to``, so the first bad endpoint wins.
            for endpoint in (edge.from_node, edge.to_node):
                if endpoint not in node_ids:
                    raise ValueError(f"edge references unknown node id: {endpoint!r}")
        return edges

    @model_validator(mode="after")
    def node_ids_are_unique(self) -> "WorkflowConfig":
        """Reject configs in which two nodes share an ``id``."""
        seen: set[str] = set()
        for node in self.nodes:
            if node.id in seen:
                raise ValueError(f"duplicate node id: {node.id!r}")
            seen.add(node.id)
        return self

    @model_validator(mode="after")
    def edges_form_a_dag(self) -> "WorkflowConfig":
        """Reject configs whose edges contain a cycle, including self-loops.

        The config is documented as a DAG; a cyclic edge set has no valid
        topological (execution) order, so it must not pass validation.

        Raises:
            ValueError: naming the nodes of one detected cycle.
        """
        # TopologicalSorter takes {node: predecessors}. Seeding every node id
        # keeps isolated nodes in the graph.
        predecessors: dict[str, set[str]] = {node.id: set() for node in self.nodes}
        for edge in self.edges:
            predecessors.setdefault(edge.to_node, set()).add(edge.from_node)
        try:
            TopologicalSorter(predecessors).prepare()
        except CycleError as exc:
            # args[1] lists the cycle in edge direction; the first node is
            # repeated at the end (e.g. n1 -> n2 -> n1).
            cycle = " -> ".join(exc.args[1])
            raise ValueError(f"workflow edges contain a cycle: {cycle}") from exc
        return self
|