Files
HartOMat/backend/app/services/chat_service.py
T
Hartmut 48b5287baf fix: rollback DB session after failed tool execution in chat agent
When a tool like query_database fails (e.g., bad column name), the
SQLAlchemy session enters a failed transaction state. Subsequent
operations (like saving the assistant response) then also fail with
InFailedSQLTransactionError.

Fix: rollback the session in the except block of _execute_tool().
Also improved query_database tool description with correct column
names (category_key not category) to help the AI write valid SQL.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-15 14:32:33 +01:00

775 lines
30 KiB
Python

"""Chat service — Azure OpenAI with function calling for tenant AI agent.
Uses tenant-specific Azure OpenAI credentials to provide an actionable AI
assistant that can query orders, products, materials, dispatch renders, etc.
All operations are scoped to the user's tenant_id.
"""
import json
import logging
import uuid
from datetime import datetime
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
logger = logging.getLogger(__name__)

# ── System prompt ────────────────────────────────────────────────────────────
# Base instructions for the assistant. The text is assembled line by line so
# that each advertised capability sits on its own source line.
SYSTEM_PROMPT = "\n".join(
    (
        "You are the Schaeffler Automat AI assistant. You help users manage their automated render pipeline for Schaeffler product images.",
        "You can:",
        "- List and search orders and products",
        "- Create new render orders",
        "- Dispatch and check render status",
        "- Set material overrides",
        "- Set render overrides (format, resolution)",
        "- Check material mapping status",
        "- Query the database for statistics",
        "Always be concise and helpful. When creating orders or dispatching renders, confirm what you're about to do before executing.",
    )
)
# ── Tool definitions (OpenAI function-calling schema) ────────────────────────
def _prop(kind: str, description: str, **extra) -> dict:
    """Build one JSON-schema property: type, description, then any extras (e.g. default)."""
    return {"type": kind, "description": description, **extra}


def _function_tool(
    name: str,
    description: str,
    properties: dict | None = None,
    required: list[str] | None = None,
) -> dict:
    """Wrap a tool description in the OpenAI function-calling envelope.

    The ``required`` key is only emitted when a non-empty list is given, so
    tools without mandatory arguments carry no ``required`` entry at all.
    """
    parameters: dict = {"type": "object", "properties": properties or {}}
    if required:
        parameters["required"] = required
    return {
        "type": "function",
        "function": {"name": name, "description": description, "parameters": parameters},
    }


# One entry per tool the assistant may call; the names must match the
# dispatch in _execute_tool().
TOOLS = [
    _function_tool(
        "list_orders",
        "List recent orders with render progress. Returns order number, status, line count, and render progress.",
        {
            "status": _prop(
                "string",
                "Filter by status: draft, submitted, processing, completed, rejected. Empty for all.",
                default="",
            ),
            "limit": _prop(
                "integer",
                "Maximum number of orders to return (default 10, max 50).",
                default=10,
            ),
        },
    ),
    _function_tool(
        "search_products",
        "Search products by name, PIM-ID, baureihe, or category.",
        {
            "query": _prop("string", "Search text (matches name, PIM-ID, baureihe).", default=""),
            "category": _prop("string", "Filter by category key.", default=""),
            "limit": _prop("integer", "Max results (default 20, max 50).", default=20),
        },
    ),
    _function_tool(
        "create_order",
        "Create a new render order with the given products and output type. Confirm with the user before executing.",
        {
            "product_ids": {
                "type": "array",
                "items": {"type": "string"},
                "description": "List of product UUIDs to include.",
            },
            "output_type_name": _prop(
                "string",
                "Name of the output type (e.g. 'Still Render', 'Turntable').",
            ),
            "render_overrides": _prop(
                "object",
                "Optional render setting overrides (output_format, width, height, samples, engine).",
            ),
            "material_override": _prop(
                "string",
                "Optional SCHAEFFLER library material name to apply to all lines.",
                default="",
            ),
        },
        ["product_ids", "output_type_name"],
    ),
    _function_tool(
        "dispatch_renders",
        "Dispatch (or retry) renders for all pending/failed lines in an order.",
        {"order_id": _prop("string", "UUID of the order.")},
        ["order_id"],
    ),
    _function_tool(
        "get_order_status",
        "Get detailed status of a specific order including all lines and render progress.",
        {"order_id": _prop("string", "UUID of the order.")},
        ["order_id"],
    ),
    _function_tool(
        "set_material_override",
        "Set a material override on all lines of an order. All parts will be rendered with this single material.",
        {
            "order_id": _prop("string", "UUID of the order."),
            "material_name": _prop(
                "string",
                "SCHAEFFLER library material name, or empty string to clear.",
            ),
        },
        ["order_id", "material_name"],
    ),
    _function_tool(
        "set_render_overrides",
        "Set render overrides on all lines of an order (output_format, width, height, samples, engine, etc.).",
        {
            "order_id": _prop("string", "UUID of the order."),
            "render_overrides": _prop(
                "object",
                "Dict of render setting overrides. Pass null/empty to clear.",
            ),
        },
        ["order_id", "render_overrides"],
    ),
    _function_tool(
        "get_render_stats",
        "Get render pipeline statistics: queue status, throughput, product/order counts.",
    ),
    _function_tool(
        "check_materials",
        "Check if all materials in an order are mapped to library materials. Returns unmapped materials.",
        {"order_id": _prop("string", "UUID of the order to check.")},
        ["order_id"],
    ),
    _function_tool(
        "query_database",
        "Execute a read-only SQL SELECT query against the database. Key columns: products(id, name, pim_id, category_key, cad_file_id, is_active), orders(id, order_number, status, tenant_id), order_lines(id, order_id, product_id, render_status, material_override, render_overrides). Use :tenant_id parameter for tenant filtering. Category is 'category_key' not 'category'.",
        {"sql": _prop("string", "A SELECT SQL query to execute.")},
        ["sql"],
    ),
]
# ── Azure OpenAI client helper ───────────────────────────────────────────────
def _resolve_credentials(tenant_config: dict | None) -> dict:
    """Resolve Azure OpenAI credentials from tenant config or global settings.

    When ``tenant_config`` is present and has a truthy ``ai_enabled`` flag,
    each connection value comes from the tenant and falls back to the global
    ``settings`` value when the tenant entry is missing or empty.  Otherwise
    only the global settings are used.

    ``ai_max_tokens`` and ``ai_temperature`` fall back to 1000 and 0.3 when
    the key is absent, ``None`` or an empty string.  An explicit ``0`` is
    respected (a temperature of 0 is a valid choice).

    Returns:
        Dict with keys ``api_key``, ``endpoint``, ``deployment``,
        ``api_version``, ``max_tokens`` (int) and ``temperature`` (float).

    Raises:
        ValueError: if a tenant-supplied numeric value cannot be converted.
    """
    default_max_tokens = 1000
    default_temperature = 0.3

    def _number(value, fallback, cast):
        # A key that exists with a null/blank value must not reach int()/float().
        if value is None or value == "":
            return fallback
        return cast(value)

    if tenant_config and tenant_config.get("ai_enabled"):
        return {
            "api_key": tenant_config.get("ai_api_key") or settings.azure_openai_api_key,
            "endpoint": tenant_config.get("ai_endpoint") or settings.azure_openai_endpoint,
            "deployment": tenant_config.get("ai_deployment") or settings.azure_openai_deployment,
            "api_version": tenant_config.get("ai_api_version") or settings.azure_openai_api_version,
            "max_tokens": _number(tenant_config.get("ai_max_tokens"), default_max_tokens, int),
            "temperature": _number(tenant_config.get("ai_temperature"), default_temperature, float),
        }
    return {
        "api_key": settings.azure_openai_api_key,
        "endpoint": settings.azure_openai_endpoint,
        "deployment": settings.azure_openai_deployment,
        "api_version": settings.azure_openai_api_version,
        "max_tokens": default_max_tokens,
        "temperature": default_temperature,
    }
# ── Tool execution (async, tenant-scoped) ────────────────────────────────────
async def _execute_tool(
    name: str,
    arguments: dict,
    tenant_id: str,
    user_id: str,
    db: AsyncSession,
) -> str:
    """Execute a tool call and return the result as a JSON string.

    Args:
        name: Tool name as advertised in ``TOOLS``.
        arguments: Decoded tool arguments; they are passed to the tool
            implementation as keyword arguments.  ``None`` is treated as
            "no arguments".
        tenant_id: Tenant every tool is scoped to.
        user_id: Acting user; only ``create_order`` receives it.
        db: Session shared with the caller.

    Returns:
        The tool's own result string, or a JSON object with an ``error`` key
        for an unknown tool, malformed arguments or a failed execution.
        This function does not raise for tool failures.
    """
    # Reject malformed arguments up front: they are a caller problem, not a
    # database problem, so they must not trigger the rollback below.
    if arguments is None:
        arguments = {}
    elif not isinstance(arguments, dict):
        return json.dumps({"error": f"Arguments for tool {name} must be a JSON object"})

    try:
        if name == "list_orders":
            return await _tool_list_orders(db, tenant_id, **arguments)
        elif name == "search_products":
            return await _tool_search_products(db, tenant_id, **arguments)
        elif name == "create_order":
            return await _tool_create_order(db, tenant_id, user_id, **arguments)
        elif name == "dispatch_renders":
            return await _tool_dispatch_renders(db, tenant_id, **arguments)
        elif name == "get_order_status":
            return await _tool_get_order_status(db, tenant_id, **arguments)
        elif name == "set_material_override":
            return await _tool_set_material_override(db, tenant_id, **arguments)
        elif name == "set_render_overrides":
            return await _tool_set_render_overrides(db, tenant_id, **arguments)
        elif name == "get_render_stats":
            return await _tool_get_render_stats(db, tenant_id)
        elif name == "check_materials":
            return await _tool_check_materials(db, tenant_id, **arguments)
        elif name == "query_database":
            return await _tool_query_database(db, tenant_id, **arguments)
        else:
            return json.dumps({"error": f"Unknown tool: {name}"})
    except Exception as exc:
        logger.warning("Tool execution failed: %s(%s): %s", name, arguments, exc)
        # Roll back the DB session so subsequent operations still work.
        # NOTE(review): this rolls back the whole session transaction, so any
        # work the caller has not committed yet is discarded as well.
        try:
            await db.rollback()
        except Exception:
            # Still best-effort, but leave a trace: a session that cannot be
            # rolled back will make every later statement fail too.
            logger.exception("Rollback after failed tool %s did not succeed", name)
        return json.dumps({"error": str(exc)[:500]})
async def _tool_list_orders(db: AsyncSession, tenant_id: str, status: str = "", limit: int = 10) -> str:
    """Return the tenant's orders, newest first, as a JSON array string.

    Every entry holds the order id, number, status and creation time plus
    the total, completed and failed line counts.  A non-empty ``status``
    restricts the result to that order status; ``limit`` is clamped to the
    range 1..50.
    """
    bind: dict = {"tenant_id": tenant_id, "limit": max(1, min(limit, 50))}
    status_clause = ""
    if status:
        status_clause = " AND o.status = :status"
        bind["status"] = status
    base_query = """
SELECT o.id, o.order_number, o.status, o.created_at,
COUNT(ol.id) AS line_count,
COUNT(ol.id) FILTER (WHERE ol.render_status = 'completed') AS completed_lines,
COUNT(ol.id) FILTER (WHERE ol.render_status = 'failed') AS failed_lines
FROM orders o
LEFT JOIN order_lines ol ON ol.order_id = o.id
WHERE o.tenant_id = :tenant_id
"""
    statement = base_query + status_clause + " GROUP BY o.id ORDER BY o.created_at DESC LIMIT :limit"
    orders = (await db.execute(text(statement), bind)).mappings().all()
    # default=str stringifies values json cannot encode natively.
    return json.dumps([dict(order) for order in orders], indent=2, default=str)
async def _tool_search_products(db: AsyncSession, tenant_id: str, query: str = "", category: str = "", limit: int = 20) -> str:
    """Search the tenant's products and return the matches as a JSON array string.

    ``query`` is matched case-insensitively as a substring of name, PIM-ID
    or baureihe; ``category`` must equal the product's ``category_key``.
    Both filters are optional.  Results are ordered by name and ``limit`` is
    clamped to the range 1..50.
    """
    bind: dict = {"tenant_id": tenant_id, "limit": max(1, min(limit, 50))}
    filters = ""
    if query:
        filters += " AND (p.name ILIKE :q OR p.pim_id ILIKE :q OR p.baureihe ILIKE :q)"
        bind["q"] = f"%{query}%"
    if category:
        filters += " AND p.category_key = :category"
        bind["category"] = category
    base_query = """
SELECT p.id, p.name, p.pim_id, p.category_key, p.baureihe,
p.cad_file_id IS NOT NULL AS has_step,
cf.processing_status
FROM products p
LEFT JOIN cad_files cf ON cf.id = p.cad_file_id
WHERE p.tenant_id = :tenant_id
"""
    statement = base_query + filters + " ORDER BY p.name LIMIT :limit"
    products = (await db.execute(text(statement), bind)).mappings().all()
    return json.dumps([dict(product) for product in products], indent=2, default=str)
async def _tool_create_order(
    db: AsyncSession,
    tenant_id: str,
    user_id: str,
    product_ids: list[str] | None = None,
    output_type_name: str = "",
    render_overrides: dict | None = None,
    material_override: str = "",
) -> str:
    """Create an order via internal httpx call to the backend API.

    Args:
        db: Session used to look up the output type by name.
        tenant_id: Tenant the service token is issued for.
        user_id: User the service token is issued for.
        product_ids: Products to add, one order line each.  A single id
            passed as a bare string is accepted as a one-element list.
        output_type_name: Output type name, matched with ILIKE against
            active output types; empty means no output type on the lines.
        render_overrides: Optional overrides copied onto every line.
        material_override: Optional material name copied onto every line.

    Returns:
        JSON string with ``order_id``, ``order_number``, ``status`` and
        ``line_count``, or a JSON object with an ``error`` key.
    """
    import httpx

    # A bare string would otherwise be iterated character by character and
    # produce one bogus order line per character.
    if isinstance(product_ids, str):
        product_ids = [product_ids]
    if not product_ids:
        return json.dumps({"error": "product_ids is required"})

    # Resolve output type ID from name
    ot_id = None
    if output_type_name:
        ot_result = await db.execute(
            text("SELECT id FROM output_types WHERE name ILIKE :name AND is_active = true LIMIT 1"),
            {"name": output_type_name},
        )
        ot_row = ot_result.mappings().first()
        if not ot_row:
            return json.dumps({"error": f"No active output type found matching '{output_type_name}'"})
        ot_id = str(ot_row["id"])

    lines = []
    for pid in product_ids:
        line: dict = {"product_id": pid}
        if ot_id:
            line["output_type_id"] = ot_id
        if render_overrides:
            line["render_overrides"] = render_overrides
        if material_override:
            line["material_override"] = material_override
        lines.append(line)

    # Call backend API internally using a service token
    from app.utils.auth import create_access_token

    # NOTE(review): the token is minted with the "global_admin" role no matter
    # what role the chatting user holds — confirm this elevation is intended.
    token = create_access_token(user_id, "global_admin", tenant_id)
    try:
        async with httpx.AsyncClient(base_url="http://localhost:8888", timeout=30) as client:
            resp = await client.post(
                "/api/orders",
                json={"lines": lines},
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            data = resp.json()
            return json.dumps({
                "order_id": data["id"],
                "order_number": data["order_number"],
                "status": data["status"],
                "line_count": data.get("line_count", len(lines)),
            }, indent=2)
    except Exception as exc:
        return json.dumps({"error": f"Failed to create order: {exc}"})
async def _tool_dispatch_renders(db: AsyncSession, tenant_id: str, order_id: str = "") -> str:
    """Dispatch renders for an order by POSTing to the backend API.

    The order must exist for ``tenant_id``; otherwise an error JSON is
    returned without calling the API.  On success the API's JSON reply is
    returned re-serialised; any failure of the HTTP call is reported as an
    error JSON instead of being raised.
    """
    import httpx
    from app.utils.auth import create_access_token

    # Verify order belongs to tenant
    ownership = await db.execute(
        text("SELECT id FROM orders WHERE id = :oid AND tenant_id = :tid"),
        {"oid": order_id, "tid": tenant_id},
    )
    owned_row = ownership.first()
    if not owned_row:
        return json.dumps({"error": "Order not found or not in your tenant"})

    # Service token built from the nil UUID, the "global_admin" role string
    # and the tenant id.
    service_token = create_access_token(str(uuid.UUID(int=0)), "global_admin", tenant_id)
    auth_headers = {"Authorization": f"Bearer {service_token}"}
    endpoint = f"/api/orders/{order_id}/dispatch-renders"
    try:
        async with httpx.AsyncClient(base_url="http://localhost:8888", timeout=60) as api:
            reply = await api.post(endpoint, headers=auth_headers)
            reply.raise_for_status()
            return json.dumps(reply.json(), indent=2, default=str)
    except Exception as exc:
        return json.dumps({"error": f"Failed to dispatch renders: {exc}"})
async def _tool_get_order_status(db: AsyncSession, tenant_id: str, order_id: str = "") -> str:
    """Return one order with all of its lines as a JSON object string.

    The order is only found when it belongs to ``tenant_id``.  The ``lines``
    field aggregates, per order line, the line id, product name, render
    status, output type name and material override.  An order without any
    lines yields an empty ``lines`` array.

    Returns:
        JSON string of the order row, or a JSON object with an ``error`` key
        when no matching order exists.
    """
    # The LEFT JOINs produce one all-NULL row for an order without lines.
    # FILTER drops that row from the aggregate and COALESCE turns the
    # resulting NULL into an empty array, so callers never see a phantom line.
    sql = """
        SELECT o.id, o.order_number, o.status, o.created_at,
               COALESCE(
                   json_agg(json_build_object(
                       'line_id', ol.id,
                       'product_name', p.name,
                       'render_status', ol.render_status,
                       'output_type', ot.name,
                       'material_override', ol.material_override
                   )) FILTER (WHERE ol.id IS NOT NULL),
                   '[]'::json
               ) AS lines
        FROM orders o
        LEFT JOIN order_lines ol ON ol.order_id = o.id
        LEFT JOIN products p ON p.id = ol.product_id
        LEFT JOIN output_types ot ON ot.id = ol.output_type_id
        WHERE o.id = :oid AND o.tenant_id = :tid
        GROUP BY o.id
    """
    result = await db.execute(text(sql), {"oid": order_id, "tid": tenant_id})
    row = result.mappings().first()
    if not row:
        return json.dumps({"error": "Order not found or not in your tenant"})
    return json.dumps(dict(row), indent=2, default=str)
async def _tool_set_material_override(db: AsyncSession, tenant_id: str, order_id: str = "", material_name: str = "") -> str:
    """Apply or clear a material override on an order via the backend API.

    The order must exist for ``tenant_id``; otherwise an error JSON is
    returned without calling the API.  An empty ``material_name`` is sent as
    ``null``.  Failures of the HTTP call are reported as an error JSON
    instead of being raised.
    """
    import httpx
    from app.utils.auth import create_access_token

    ownership = await db.execute(
        text("SELECT id FROM orders WHERE id = :oid AND tenant_id = :tid"),
        {"oid": order_id, "tid": tenant_id},
    )
    owned_row = ownership.first()
    if not owned_row:
        return json.dumps({"error": "Order not found or not in your tenant"})

    service_token = create_access_token(str(uuid.UUID(int=0)), "global_admin", tenant_id)
    auth_headers = {"Authorization": f"Bearer {service_token}"}
    endpoint = f"/api/orders/{order_id}/batch-material-override"
    # Falsy names (empty string) become None, i.e. JSON null.
    body = {"material_override": material_name or None}
    try:
        async with httpx.AsyncClient(base_url="http://localhost:8888", timeout=30) as api:
            reply = await api.post(endpoint, json=body, headers=auth_headers)
            reply.raise_for_status()
            return json.dumps(reply.json(), indent=2, default=str)
    except Exception as exc:
        return json.dumps({"error": f"Failed to set material override: {exc}"})
async def _tool_set_render_overrides(db: AsyncSession, tenant_id: str, order_id: str = "", render_overrides: dict | None = None) -> str:
    """Apply render overrides to an order via the backend API.

    The order must exist for ``tenant_id``; otherwise an error JSON is
    returned without calling the API.  ``render_overrides`` is forwarded
    unchanged (``None`` is sent as ``null``).  Failures of the HTTP call are
    reported as an error JSON instead of being raised.
    """
    import httpx
    from app.utils.auth import create_access_token

    ownership = await db.execute(
        text("SELECT id FROM orders WHERE id = :oid AND tenant_id = :tid"),
        {"oid": order_id, "tid": tenant_id},
    )
    owned_row = ownership.first()
    if not owned_row:
        return json.dumps({"error": "Order not found or not in your tenant"})

    service_token = create_access_token(str(uuid.UUID(int=0)), "global_admin", tenant_id)
    auth_headers = {"Authorization": f"Bearer {service_token}"}
    endpoint = f"/api/orders/{order_id}/batch-render-overrides"
    body = {"render_overrides": render_overrides}
    try:
        async with httpx.AsyncClient(base_url="http://localhost:8888", timeout=30) as api:
            reply = await api.post(endpoint, json=body, headers=auth_headers)
            reply.raise_for_status()
            return json.dumps(reply.json(), indent=2, default=str)
    except Exception as exc:
        return json.dumps({"error": f"Failed to set render overrides: {exc}"})
async def _tool_get_render_stats(db: AsyncSession, tenant_id: str) -> str:
    """Return tenant-wide counters as a JSON object string.

    The counters are: total orders, total products, total order lines, and
    the number of order lines whose render status is completed, failed,
    pending or processing.  All counts are restricted to ``tenant_id``.
    """
    counters_query = text("""
SELECT
(SELECT count(*) FROM orders WHERE tenant_id = :tid) AS total_orders,
(SELECT count(*) FROM products WHERE tenant_id = :tid) AS total_products,
(SELECT count(*) FROM order_lines ol
JOIN orders o ON o.id = ol.order_id
WHERE o.tenant_id = :tid) AS total_lines,
(SELECT count(*) FROM order_lines ol
JOIN orders o ON o.id = ol.order_id
WHERE o.tenant_id = :tid AND ol.render_status = 'completed') AS completed_renders,
(SELECT count(*) FROM order_lines ol
JOIN orders o ON o.id = ol.order_id
WHERE o.tenant_id = :tid AND ol.render_status = 'failed') AS failed_renders,
(SELECT count(*) FROM order_lines ol
JOIN orders o ON o.id = ol.order_id
WHERE o.tenant_id = :tid AND ol.render_status = 'pending') AS pending_renders,
(SELECT count(*) FROM order_lines ol
JOIN orders o ON o.id = ol.order_id
WHERE o.tenant_id = :tid AND ol.render_status = 'processing') AS active_renders
""")
    counters = (await db.execute(counters_query, {"tid": tenant_id})).mappings().first()
    stats = dict(counters) if counters else {}
    return json.dumps(stats, indent=2, default=str)
async def _tool_check_materials(db: AsyncSession, tenant_id: str, order_id: str = "") -> str:
    """Fetch the material-mapping check for an order from the backend API.

    The order must exist for ``tenant_id``; otherwise an error JSON is
    returned without calling the API.  On success the API's JSON reply is
    returned re-serialised; failures of the HTTP call are reported as an
    error JSON instead of being raised.
    """
    import httpx
    from app.utils.auth import create_access_token

    ownership = await db.execute(
        text("SELECT id FROM orders WHERE id = :oid AND tenant_id = :tid"),
        {"oid": order_id, "tid": tenant_id},
    )
    owned_row = ownership.first()
    if not owned_row:
        return json.dumps({"error": "Order not found or not in your tenant"})

    service_token = create_access_token(str(uuid.UUID(int=0)), "global_admin", tenant_id)
    auth_headers = {"Authorization": f"Bearer {service_token}"}
    endpoint = f"/api/orders/{order_id}/check-materials"
    try:
        async with httpx.AsyncClient(base_url="http://localhost:8888", timeout=30) as api:
            reply = await api.get(endpoint, headers=auth_headers)
            reply.raise_for_status()
            return json.dumps(reply.json(), indent=2, default=str)
    except Exception as exc:
        return json.dumps({"error": f"Failed to check materials: {exc}"})
async def _tool_query_database(db: AsyncSession, tenant_id: str, sql: str = "") -> str:
    """Execute a read-only SQL query with ``:tenant_id`` bound to the tenant.

    The statement must start with SELECT or WITH, must not contain a
    data-modifying or DDL keyword anywhere (including inside a CTE or after
    a comment), and must be a single statement.  It runs inside a SAVEPOINT
    so that a failing query leaves the surrounding transaction usable.

    Args:
        db: Session shared with the caller.
        tenant_id: Value bound to the ``:tenant_id`` parameter.
        sql: The query text.

    Returns:
        JSON array string with at most 100 rows, the plain text
        ``"Query returned 0 rows."`` for an empty result, or a JSON object
        with an ``error`` key.

    NOTE(review): the tenant filter is only *offered* as a bind parameter;
    nothing here forces the query to use it.  The keyword screen is a
    defence-in-depth measure, not a substitute for a read-only database role.
    """
    import re

    statement = (sql or "").strip()
    sql_upper = statement.upper()
    if not sql_upper.startswith(("SELECT", "WITH")):
        return json.dumps({"error": "Only SELECT/WITH queries are allowed (read-only)."})

    # Whole-word match over the entire text.  Requiring a space on both sides
    # missed "(DELETE ..." inside a CTE and keywords separated by newlines or
    # tabs; cutting the text at the first comment marker left the rest
    # unchecked.  Word boundaries keep names such as "created_at" legal.
    forbidden = re.search(
        r"\b(INSERT|UPDATE|DELETE|MERGE|DROP|ALTER|TRUNCATE|CREATE)\b",
        sql_upper,
    )
    if forbidden:
        return json.dumps({"error": f"{forbidden.group(1)} statements are not allowed (read-only)."})

    # Only a trailing semicolon is tolerated; anything after one would be a
    # second statement.
    statement = statement.rstrip().rstrip(";").rstrip()
    if ";" in statement:
        return json.dumps({"error": "Only a single statement is allowed (read-only)."})

    # Auto-inject tenant_id parameter for safety
    # The AI should use :tenant_id in its queries; we always bind it
    try:
        # SAVEPOINT: a failing query is rolled back on its own, so the session
        # does not stay in a failed-transaction state and earlier uncommitted
        # work is kept.
        async with db.begin_nested():
            result = await db.execute(text(statement), {"tenant_id": tenant_id})
            rows = result.mappings().all()
    except Exception as exc:
        return json.dumps({"error": f"Query error: {str(exc)[:500]}"})

    if not rows:
        return "Query returned 0 rows."
    return json.dumps([dict(r) for r in rows[:100]], indent=2, default=str)
# ── Message persistence ──────────────────────────────────────────────────────
async def _save_message(
    db: AsyncSession,
    tenant_id: str,
    user_id: str,
    session_id: str,
    role: str,
    content: str,
    context_type: str | None = None,
    context_id: str | None = None,
    token_count: int | None = None,
) -> dict:
    """Insert one chat message row and return its public representation.

    A fresh UUID and the current naive UTC timestamp are generated here.
    The insert is executed on ``db`` but not committed by this function.

    Returns:
        Dict with ``id``, ``role``, ``content``, ``context_type``,
        ``context_id`` (stringified, or ``None`` when falsy),
        ``token_count`` and ``created_at`` (ISO-8601 string).
    """
    message_id = str(uuid.uuid4())
    created = datetime.utcnow()
    record = {
        "id": message_id,
        "tenant_id": tenant_id,
        "user_id": user_id,
        "session_id": session_id,
        "role": role,
        "content": content,
        "context_type": context_type,
        "context_id": context_id,
        "token_count": token_count,
        "created_at": created,
    }
    insert_statement = text("""
INSERT INTO chat_messages (id, tenant_id, user_id, session_id, role, content,
context_type, context_id, token_count, created_at)
VALUES (:id, :tenant_id, :user_id, :session_id, :role, :content,
:context_type, :context_id, :token_count, :created_at)
""")
    await db.execute(insert_statement, record)

    if context_id:
        public_context_id = str(context_id)
    else:
        public_context_id = None
    return {
        "id": message_id,
        "role": role,
        "content": content,
        "context_type": context_type,
        "context_id": public_context_id,
        "token_count": token_count,
        "created_at": created.isoformat(),
    }
async def _load_session_messages(
    db: AsyncSession,
    session_id: str,
    tenant_id: str,
    limit: int = 50,
) -> list[dict]:
    """Load the most recent conversation history for a session.

    Args:
        db: Session to query.
        session_id: Chat session whose messages are loaded.
        tenant_id: Tenant the messages must belong to.
        limit: Maximum number of messages to return (default 50).

    Returns:
        Up to ``limit`` ``{"role", "content"}`` dicts in chronological order
        (oldest first), taken from the *end* of the conversation.
    """
    # Select newest-first so LIMIT keeps the latest messages; ordering
    # ascending with LIMIT would freeze the context on the first messages of
    # a long session and drop everything said afterwards.
    result = await db.execute(
        text("""
            SELECT role, content FROM chat_messages
            WHERE session_id = :sid AND tenant_id = :tid
            ORDER BY created_at DESC
            LIMIT :lim
        """),
        {"sid": session_id, "tid": tenant_id, "lim": limit},
    )
    rows = result.mappings().all()
    # Flip back to chronological order for the model.
    return [{"role": r["role"], "content": r["content"]} for r in reversed(rows)]
# ── Main chat function ───────────────────────────────────────────────────────
async def chat_with_agent(
    message: str,
    session_id: str,
    tenant_id: str,
    user_id: str,
    db: AsyncSession,
    tenant_config: dict | None = None,
    context_type: str | None = None,
    context_id: str | None = None,
) -> dict:
    """Process a user message through the Azure OpenAI agent with function calling.

    The user message is stored and committed first, then the model is called
    with the session history and the tool schema.  Tool calls requested by
    the model are executed and fed back for at most 10 rounds.  Finally the
    assistant's text is stored with the accumulated token count and committed.

    Args:
        message: The user's text.
        session_id: Chat session the message belongs to.
        tenant_id: Tenant all tools and history are scoped to.
        user_id: Acting user.
        db: Session used for persistence and tool execution.
        tenant_config: Optional tenant settings passed to credential resolution.
        context_type: Optional context label stored with the messages and
            added to the system prompt together with ``context_id``.
        context_id: Optional context identifier.

    Returns:
        {"session_id": str, "user_message": dict, "assistant_message": dict}.

    Raises:
        ValueError: if no API key or endpoint is configured.
    """
    # Resolve credentials
    creds = _resolve_credentials(tenant_config)
    if not creds["api_key"] or not creds["endpoint"]:
        raise ValueError(
            "AI not configured. Ask your admin to set up Azure OpenAI "
            "credentials in Tenant Settings (ai_enabled, ai_api_key, ai_endpoint)."
        )

    # The async client keeps the event loop free while waiting for the model;
    # the synchronous client would block every other request during the call.
    from openai import AsyncAzureOpenAI

    # Build message history
    history = await _load_session_messages(db, session_id, tenant_id)

    # Build context-aware system prompt
    system_content = SYSTEM_PROMPT
    if context_type and context_id:
        system_content += f"\n\nCurrent context: {context_type} {context_id}"
    system_content += f"\n\nThe user's tenant_id is '{tenant_id}'. Always filter queries by this tenant_id."

    messages: list[dict] = [
        {"role": "system", "content": system_content},
        *history,
        {"role": "user", "content": message},
    ]

    # Save user message and commit it right away: a failing tool rolls the
    # session back, which would otherwise silently discard this row even
    # though it is returned to the caller as saved.
    user_msg = await _save_message(
        db, tenant_id, user_id, session_id, "user", message,
        context_type=context_type, context_id=context_id,
    )
    await db.commit()

    max_iterations = 10
    iteration = 0
    total_tokens = 0

    # The context manager closes the client's HTTP connections on exit.
    async with AsyncAzureOpenAI(
        api_key=creds["api_key"],
        azure_endpoint=creds["endpoint"],
        api_version=creds["api_version"],
    ) as client:

        async def _complete():
            """Run one completion over the current messages and count its tokens."""
            nonlocal total_tokens
            completion = await client.chat.completions.create(
                model=creds["deployment"],
                messages=messages,
                tools=TOOLS,
                tool_choice="auto",
                max_completion_tokens=creds["max_tokens"],
                temperature=creds["temperature"],
            )
            if completion.usage:
                total_tokens += completion.usage.total_tokens
            return completion

        # OpenAI function-calling loop
        response = await _complete()
        while response.choices[0].message.tool_calls and iteration < max_iterations:
            iteration += 1
            assistant_msg = response.choices[0].message

            # Append assistant message with tool calls to context
            messages.append({
                "role": "assistant",
                "content": assistant_msg.content or "",
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {"name": tc.function.name, "arguments": tc.function.arguments},
                    }
                    for tc in assistant_msg.tool_calls
                ],
            })

            # Execute each tool call; every call gets exactly one tool message
            # in reply, even when its arguments are unusable.
            for tool_call in assistant_msg.tool_calls:
                fn_name = tool_call.function.name
                raw_args = tool_call.function.arguments
                try:
                    # Argument-less tools may arrive with an empty string.
                    fn_args = json.loads(raw_args) if raw_args and raw_args.strip() else {}
                except json.JSONDecodeError as exc:
                    logger.warning("Chat tool call %s has malformed arguments: %s", fn_name, exc)
                    tool_result = json.dumps({"error": f"Invalid JSON arguments for {fn_name}: {exc}"})
                else:
                    if fn_args is None:
                        fn_args = {}
                    if isinstance(fn_args, dict):
                        logger.info("Chat tool call: %s(%s) [tenant=%s]", fn_name, fn_args, tenant_id)
                        tool_result = await _execute_tool(fn_name, fn_args, tenant_id, user_id, db)
                    else:
                        tool_result = json.dumps({"error": f"Arguments for {fn_name} must be a JSON object"})
                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": tool_result,
                })

            # Next LLM call with tool results
            response = await _complete()

    # Extract final text response
    final_message = response.choices[0].message
    final_content = final_message.content or ""
    if not final_content and final_message.tool_calls:
        # The round budget ran out while the model still wanted tools; do not
        # store an empty reply.
        final_content = (
            "I could not finish this request within the allowed number of tool calls. "
            "Please try again with a narrower request."
        )

    # Save assistant response
    assistant_msg_out = await _save_message(
        db, tenant_id, user_id, session_id, "assistant", final_content,
        context_type=context_type, context_id=context_id,
        token_count=total_tokens,
    )
    await db.commit()
    return {
        "session_id": session_id,
        "user_message": user_msg,
        "assistant_message": assistant_msg_out,
    }