feat: add graph workflow fallback and retry metadata

2026-04-07 10:56:45 +02:00
parent c17b7d2e8f
commit f9d4da52b9
9 changed files with 473 additions and 39 deletions
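
Both policies are read from each node's params by the helpers added below. A minimal sketch of the expected shape, assuming the defaults in _retry_policy and _failure_policy (values are illustrative, every key is optional):

    node_params = {
        "retry_policy": {"max_attempts": 3},   # ints outside 1..5 are clamped
        "failure_policy": {
            "halt_workflow": True,             # default when omitted
            "fallback_to_legacy": False,       # default when omitted
        },
    }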
@@ -53,6 +53,17 @@ _ORDER_LINE_RENDER_STEPS = {
 }
 
 
+def find_unsupported_graph_nodes(workflow_context: WorkflowContext) -> list[str]:
+    unsupported: list[str] = []
+    for node in workflow_context.ordered_nodes:
+        if node.step in _BRIDGE_EXECUTORS:
+            continue
+        if STEP_TASK_MAP.get(node.step) is not None:
+            continue
+        unsupported.append(node.id)
+    return unsupported
+
+
 def execute_graph_workflow(
     session: Session,
     workflow_context: WorkflowContext,
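
find_unsupported_graph_nodes reports nodes that neither a bridge executor nor STEP_TASK_MAP can run. A hedged sketch of how a call site might combine it with the fallback_to_legacy policy; run_legacy_workflow and the exact execute_graph_workflow signature are assumptions, not shown in this diff:

    def run_with_fallback(session, workflow_context, state):
        # Nodes the graph engine cannot execute force the legacy path.
        unsupported = find_unsupported_graph_nodes(workflow_context)
        if unsupported:
            return run_legacy_workflow(session, workflow_context)  # hypothetical
        return execute_graph_workflow(session, workflow_context, state)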
@@ -82,48 +93,102 @@ def execute_graph_workflow(
             )
             continue
+        retry_policy = _retry_policy(node.params)
+        failure_policy = _failure_policy(node.params)
         metadata = _base_output(node_result.output, node)
+        metadata["retry_policy"] = retry_policy
+        metadata["failure_policy"] = failure_policy
         definition = get_node_definition(node.step)
         bridge_executor = _BRIDGE_EXECUTORS.get(node.step)
         if bridge_executor is not None:
-            started = time.perf_counter()
-            node_result.status = "running"
-            node_result.output = dict(metadata)
-            session.flush()
-            try:
-                payload, status, log_message = bridge_executor(
-                    session=session,
-                    workflow_context=workflow_context,
-                    state=state,
-                    node_params=node.params,
-                )
-            except Exception as exc:
-                node_result.status = "failed"
-                node_result.log = str(exc)[:2000]
-                node_result.duration_s = round(time.perf_counter() - started, 4)
-                node_result.output = dict(metadata)
-                raise WorkflowGraphRuntimeError(
-                    f"Node '{node.id}' ({node.step.value}) failed: {exc}"
-                ) from exc
-            if payload:
-                metadata.update(payload)
-                state.node_outputs[node.id] = payload
-            node_result.status = status
-            node_result.log = log_message
-            node_result.output = dict(metadata)
-            node_result.duration_s = round(time.perf_counter() - started, 4)
-            session.flush()
-            if status == "failed":
-                raise WorkflowGraphRuntimeError(
-                    f"Node '{node.id}' ({node.step.value}) failed: {log_message or 'unknown error'}"
-                )
-            if status == "skipped":
-                skipped_node_ids.append(node.id)
+            max_attempts = retry_policy["max_attempts"]
+            last_error: str | None = None
+            for attempt in range(1, max_attempts + 1):
+                started = time.perf_counter()
+                attempt_output = dict(metadata)
+                attempt_output["attempt_count"] = attempt
+                attempt_output["max_attempts"] = max_attempts
+                node_result.status = "running"
+                node_result.output = attempt_output
+                session.flush()
+                try:
+                    payload, status, log_message = bridge_executor(
+                        session=session,
+                        workflow_context=workflow_context,
+                        state=state,
+                        node_params=node.params,
+                    )
+                except Exception as exc:
+                    last_error = str(exc)[:2000]
+                    if attempt < max_attempts:
+                        retry_output = dict(attempt_output)
+                        retry_output["last_error"] = last_error
+                        retry_output["retry_state"] = "retrying"
+                        node_result.status = "retrying"
+                        node_result.log = f"Attempt {attempt}/{max_attempts} failed: {last_error}"
+                        node_result.output = retry_output
+                        node_result.duration_s = round(time.perf_counter() - started, 4)
+                        session.flush()
+                        continue
+                    failed_output = dict(attempt_output)
+                    failed_output["last_error"] = last_error
+                    failed_output["retry_exhausted"] = True
+                    node_result.status = "failed"
+                    node_result.log = last_error
+                    node_result.duration_s = round(time.perf_counter() - started, 4)
+                    node_result.output = failed_output
+                    session.flush()
+                    raise WorkflowGraphRuntimeError(
+                        f"Node '{node.id}' ({node.step.value}) failed: {exc}"
+                    ) from exc
+                if payload:
+                    metadata.update(payload)
+                    state.node_outputs[node.id] = payload
+                final_output = dict(metadata)
+                final_output["attempt_count"] = attempt
+                final_output["max_attempts"] = max_attempts
+                if last_error is not None:
+                    final_output["last_error"] = last_error
+                    final_output["retry_state"] = "recovered"
+                node_result.status = status
+                node_result.log = log_message
+                node_result.output = final_output
+                node_result.duration_s = round(time.perf_counter() - started, 4)
+                session.flush()
+                if status == "failed":
+                    last_error = (log_message or "unknown error")[:2000]
+                    if attempt < max_attempts:
+                        retry_output = dict(final_output)
+                        retry_output["last_error"] = last_error
+                        retry_output["retry_state"] = "retrying"
+                        node_result.status = "retrying"
+                        node_result.log = f"Attempt {attempt}/{max_attempts} failed: {last_error}"
+                        node_result.output = retry_output
+                        session.flush()
+                        continue
+                    failed_output = dict(final_output)
+                    failed_output["last_error"] = last_error
+                    failed_output["retry_exhausted"] = True
+                    node_result.status = "failed"
+                    node_result.log = last_error
+                    node_result.output = failed_output
+                    session.flush()
+                    raise WorkflowGraphRuntimeError(
+                        f"Node '{node.id}' ({node.step.value}) failed: {last_error}"
+                    )
+                if status == "skipped":
+                    skipped_node_ids.append(node.id)
+                break
             continue
         task_name = STEP_TASK_MAP.get(node.step)
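
Because node_result.output is rewritten on every attempt, the stored metadata always reflects the most recent state of the node. Illustratively, a bridge node that failed once and then succeeded on its second attempt would end with output along these lines (keys copied from _base_output are omitted, and the error text is invented):

    {
        "retry_policy": {"max_attempts": 3},
        "failure_policy": {"halt_workflow": True, "fallback_to_legacy": False},
        "attempt_count": 2,
        "max_attempts": 3,
        "last_error": "bridge executor timed out",  # illustrative
        "retry_state": "recovered",
    }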
@@ -147,6 +212,8 @@ def execute_graph_workflow(
         metadata["task_id"] = result.id
         if definition is not None:
             metadata["execution_kind"] = definition.execution_kind
+        metadata["attempt_count"] = 1
+        metadata["max_attempts"] = retry_policy["max_attempts"]
         node_result.status = "queued"
         node_result.output = metadata
         node_result.log = None
@@ -203,6 +270,29 @@ def _base_output(existing: dict[str, Any] | None, node) -> dict[str, Any]:
     return metadata
 
 
+def _retry_policy(node_params: dict[str, Any]) -> dict[str, Any]:
+    raw = node_params.get("retry_policy")
+    if not isinstance(raw, dict):
+        raw = {}
+    try:
+        max_attempts = int(raw.get("max_attempts", 1))
+    except (TypeError, ValueError):
+        max_attempts = 1
+    return {
+        "max_attempts": max(1, min(max_attempts, 5)),
+    }
+
+
+def _failure_policy(node_params: dict[str, Any]) -> dict[str, Any]:
+    raw = node_params.get("failure_policy")
+    if not isinstance(raw, dict):
+        raw = {}
+    return {
+        "halt_workflow": bool(raw.get("halt_workflow", True)),
+        "fallback_to_legacy": bool(raw.get("fallback_to_legacy", False)),
+    }
+
+
 def _serialize_setup_result(result: OrderLineRenderSetupResult) -> dict[str, Any]:
     payload: dict[str, Any] = {
         "setup_status": result.status,