feat: tenant AI chat agent with function calling
Actionable AI assistant that uses per-tenant Azure OpenAI credentials to execute natural language commands against the render pipeline. Backend: - ChatMessage model + migration (session-based conversations) - Chat service with 10 OpenAI function-calling tools: list_orders, search_products, create_order, dispatch_renders, get_order_status, set_material_override, set_render_overrides, get_render_stats, check_materials, query_database - All tools tenant-scoped (queries filtered by tenant_id) - Write operations use httpx to call backend API internally - Chat API: POST /chat/messages, GET /chat/sessions, DELETE session - Conversation history preserved in DB (last 50 messages per session) Frontend: - Slide-out ChatPanel (right side, w-96, animated) - User/assistant message styling with avatars and timestamps - Session management (new chat, session history, delete) - Typing indicator while waiting for AI response - Floating chat button in bottom-right corner - Error state for unconfigured AI tenants Example: "Render all Kugellager products as WebP at 1024x1024" → Agent calls search_products + create_order + dispatch_renders Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,48 @@
|
||||
"""add chat_messages table
|
||||
|
||||
Revision ID: 69964e910545
|
||||
Revises: f5906aaf75af
|
||||
Create Date: 2026-03-15 11:38:41.189160
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# Alembic revision graph: this migration's own id and the parent revision
# it is applied on top of.
revision: str = '69964e910545'
down_revision: Union[str, None] = 'f5906aaf75af'
# This revision carries no branch label and depends on no other revision.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``chat_messages`` table and its session/tenant indexes."""
    table_name = 'chat_messages'

    columns = [
        sa.Column('id', sa.UUID(), nullable=False),
        sa.Column('tenant_id', sa.UUID(), nullable=True),
        sa.Column('user_id', sa.UUID(), nullable=True),
        sa.Column('session_id', sa.UUID(), nullable=False),
        sa.Column('role', sa.String(length=20), nullable=False),
        sa.Column('content', sa.Text(), nullable=False),
        sa.Column('context_type', sa.String(length=50), nullable=True),
        sa.Column('context_id', sa.UUID(), nullable=True),
        sa.Column('token_count', sa.Integer(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=False),
    ]
    constraints = [
        # Messages are removed together with their tenant ...
        sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], ondelete='CASCADE'),
        # ... but are kept (with a NULL author) when the user row is deleted.
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='SET NULL'),
        sa.PrimaryKeyConstraint('id'),
    ]
    op.create_table(table_name, *columns, *constraints)

    # Same creation order as the autogenerated script: session index, then tenant index.
    index_specs = (
        ('ix_chat_messages_session_id', 'session_id'),
        ('ix_chat_messages_tenant_id', 'tenant_id'),
    )
    for index_name, column_name in index_specs:
        op.create_index(op.f(index_name), table_name, [column_name], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the two ``chat_messages`` indexes, then the table itself."""
    table_name = 'chat_messages'
    # Indexes go first, tenant index before session index.
    for index_name in ('ix_chat_messages_tenant_id', 'ix_chat_messages_session_id'):
        op.drop_index(op.f(index_name), table_name=table_name)
    op.drop_table(table_name)
|
||||
@@ -0,0 +1,209 @@
|
||||
"""Chat API endpoints for tenant AI agent conversations."""
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.utils.auth import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
||||
|
||||
# ── Pydantic schemas ─────────────────────────────────────────────────────────
|
||||
|
||||
class ChatMessageCreate(BaseModel):
    """Request body for ``POST /chat/messages``."""

    # Text the user sends to the assistant.
    message: str
    # Conversation to continue; when omitted the endpoint generates a new id.
    session_id: str | None = None
    # Optional context fields, handed to the chat service unchanged.
    context_type: str | None = None
    context_id: str | None = None
|
||||
|
||||
|
||||
class ChatMessageOut(BaseModel):
    """A single stored chat message as returned by the chat endpoints."""

    id: str
    role: str
    content: str
    context_type: str | None = None
    context_id: str | None = None
    token_count: int | None = None
    # Timestamp serialized as a string (the session endpoint uses ``isoformat()``).
    created_at: str
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
    """Response model of ``POST /chat/messages``: the session id plus both messages."""

    session_id: str
    user_message: ChatMessageOut
    assistant_message: ChatMessageOut
|
||||
|
||||
|
||||
class ChatSessionSummary(BaseModel):
    """One row of the session list returned by ``GET /chat/sessions``."""

    session_id: str
    # Content of the newest message; the list endpoint truncates it to 200 characters.
    last_message: str
    message_count: int
    # Timestamp of the earliest message in the session, as an ISO string.
    created_at: str
|
||||
|
||||
|
||||
# ── Endpoints ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/messages", response_model=ChatResponse)
async def send_message(
    body: ChatMessageCreate,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Send a message to the AI assistant and get a response.

    Creates a new session if session_id is not provided.
    Uses the tenant's Azure OpenAI credentials for the LLM call.

    Raises:
        HTTPException 422: ``session_id`` is not a valid UUID, or the chat
            service raised ``ValueError`` (AI not configured).
        HTTPException 404: ``session_id`` refers to a session that already
            contains messages belonging to someone else.
        HTTPException 500: any other failure inside the chat service.
    """
    from app.services.chat_service import chat_with_agent

    if body.session_id:
        # The column is a UUID; reject malformed ids here instead of letting
        # the database driver fail, and normalize to the canonical form.
        try:
            session_id = str(uuid.UUID(body.session_id))
        except ValueError:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="session_id must be a valid UUID",
            )

        # A client-chosen id for a brand-new session is fine, but an id that
        # already holds messages of another (or a deleted) user must not be
        # continued by this caller.
        foreign = await db.execute(
            text("""
                SELECT 1 FROM chat_messages
                WHERE session_id = :sid
                  AND (user_id IS NULL OR user_id <> :uid)
                LIMIT 1
            """),
            {"sid": session_id, "uid": str(user.id)},
        )
        if foreign.first():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Session not found",
            )
    else:
        session_id = str(uuid.uuid4())

    # Load tenant config
    tenant_config = await _get_tenant_config(db, user)

    try:
        return await chat_with_agent(
            message=body.message,
            session_id=session_id,
            tenant_id=str(user.tenant_id),
            user_id=str(user.id),
            db=db,
            tenant_config=tenant_config,
            context_type=body.context_type,
            context_id=body.context_id,
        )
    except ValueError as exc:
        # AI not configured
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=str(exc),
        ) from exc
    except Exception as exc:
        # Full details go to the log only; the raw exception text is not
        # echoed back to the client.
        logger.exception("Chat error for user %s", user.id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Chat service error",
        ) from exc
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=list[ChatSessionSummary])
async def list_sessions(
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the current user's chat sessions, most recent first.

    Returns at most 50 sessions. ``last_message`` is truncated to 200
    characters and ``created_at`` is the timestamp of the session's earliest
    message.
    """
    # The last_message subquery is scoped to the same user and tenant as the
    # outer query, so a session id shared with another user can never surface
    # that user's message text here.
    result = await db.execute(
        text("""
            SELECT
                session_id::text AS session_id,
                (SELECT content FROM chat_messages cm2
                 WHERE cm2.session_id = cm.session_id
                   AND cm2.user_id = :uid AND cm2.tenant_id = :tid
                 ORDER BY cm2.created_at DESC LIMIT 1) AS last_message,
                COUNT(*) AS message_count,
                MIN(cm.created_at) AS created_at
            FROM chat_messages cm
            WHERE cm.user_id = :uid AND cm.tenant_id = :tid
            GROUP BY cm.session_id
            ORDER BY MAX(cm.created_at) DESC
            LIMIT 50
        """),
        {"uid": str(user.id), "tid": str(user.tenant_id)},
    )
    rows = result.mappings().all()
    return [
        {
            "session_id": r["session_id"],
            "last_message": (r["last_message"] or "")[:200],
            "message_count": r["message_count"],
            "created_at": r["created_at"].isoformat() if r["created_at"] else "",
        }
        for r in rows
    ]
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/messages", response_model=list[ChatMessageOut])
async def get_session_messages(
    session_id: str,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get all messages in a chat session, oldest first.

    Raises:
        HTTPException 404: ``session_id`` is not a valid UUID, or the caller
            has no messages in that session.
    """
    # A malformed id cannot match any row; answer 404 directly rather than
    # sending a non-UUID string to a UUID column.
    try:
        sid = str(uuid.UUID(session_id))
    except ValueError:
        raise HTTPException(status_code=404, detail="Session not found")

    result = await db.execute(
        text("""
            SELECT id::text, role, content, context_type,
                   context_id::text, token_count, created_at
            FROM chat_messages
            WHERE session_id = :sid AND user_id = :uid AND tenant_id = :tid
            ORDER BY created_at ASC
        """),
        {"sid": sid, "uid": str(user.id), "tid": str(user.tenant_id)},
    )
    rows = result.mappings().all()
    if not rows:
        raise HTTPException(status_code=404, detail="Session not found")
    return [
        {
            "id": r["id"],
            "role": r["role"],
            "content": r["content"],
            "context_type": r["context_type"],
            "context_id": r["context_id"],
            "token_count": r["token_count"],
            "created_at": r["created_at"].isoformat() if r["created_at"] else "",
        }
        for r in rows
    ]
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}")
async def delete_session(
    session_id: str,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete all of the caller's messages in a chat session.

    Returns:
        ``{"deleted": <row count>, "session_id": <id as given>}``.

    Raises:
        HTTPException 404: ``session_id`` is not a valid UUID, or no message
            of the caller exists in that session.
    """
    # A malformed id cannot match any row; answer 404 directly rather than
    # sending a non-UUID string to a UUID column.
    try:
        sid = str(uuid.UUID(session_id))
    except ValueError:
        raise HTTPException(status_code=404, detail="Session not found")

    result = await db.execute(
        text("""
            DELETE FROM chat_messages
            WHERE session_id = :sid AND user_id = :uid AND tenant_id = :tid
        """),
        {"sid": sid, "uid": str(user.id), "tid": str(user.tenant_id)},
    )
    await db.commit()
    deleted = result.rowcount
    if deleted == 0:
        raise HTTPException(status_code=404, detail="Session not found")
    return {"deleted": deleted, "session_id": session_id}
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
async def _get_tenant_config(db: AsyncSession, user: User) -> dict | None:
    """Fetch the ``tenant_config`` value of the tenant the user belongs to.

    Returns ``None`` when the user has no tenant or the tenant row is missing.
    """
    tenant_id = user.tenant_id
    if not tenant_id:
        return None

    lookup = text("SELECT tenant_config FROM tenants WHERE id = :tid")
    outcome = await db.execute(lookup, {"tid": str(tenant_id)})
    record = outcome.mappings().first()
    if not record:
        return None
    return record["tenant_config"]
|
||||
@@ -26,6 +26,7 @@ from app.domains.media.router import router as media_router
|
||||
from app.api.routers.asset_libraries import router as asset_libraries_router
|
||||
from app.domains.admin.dashboard_router import router as dashboard_router
|
||||
from app.api.routers.task_logs import router as task_logs_router
|
||||
from app.api.routers.chat import router as chat_router
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -95,6 +96,7 @@ app.include_router(asset_libraries_router, prefix="/api")
|
||||
app.include_router(dashboard_router, prefix="/api")
|
||||
app.include_router(task_logs_router, prefix="/api")
|
||||
app.include_router(global_render_positions_router, prefix="/api")
|
||||
app.include_router(chat_router, prefix="/api")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
||||
@@ -18,11 +18,12 @@ from app.domains.admin.models import DashboardConfig
|
||||
# Also re-export SystemSetting (no domain assigned — stays as-is)
|
||||
from app.models.system_setting import SystemSetting
|
||||
from app.models.worker_config import WorkerConfig
|
||||
from app.models.chat import ChatMessage
|
||||
|
||||
# Names re-exported by the models package. Each model is listed exactly once;
# the superseded '"DashboardConfig", "WorkerConfig",' line is dropped in favour
# of the one that also exports ChatMessage.
__all__ = [
    "Tenant", "User", "Template", "CadFile", "Product", "Order", "OrderItem", "OrderLine",
    "AuditLog", "PricingTier", "OutputType", "RenderTemplate", "ProductRenderPosition", "GlobalRenderPosition",
    "WorkflowDefinition", "WorkflowRun", "WorkflowNodeResult",
    "Material", "MaterialAlias", "AssetLibrary", "MediaAsset", "MediaAssetType", "SystemSetting",
    "DashboardConfig", "WorkerConfig", "ChatMessage",
]
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
"""Chat message model for tenant AI agent conversations."""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, DateTime, Text, ForeignKey, Integer
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class ChatMessage(Base):
    """ORM mapping of one chat message row in ``chat_messages``."""

    __tablename__ = "chat_messages"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    # Deleting the tenant cascades to its messages.
    tenant_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=True,
        index=True,
    )
    # Deleting the user keeps the message and nulls this column.
    user_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("users.id", ondelete="SET NULL"),
        nullable=True,
    )
    # Groups messages into one conversation; not a foreign key.
    session_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        nullable=False,
        index=True,
    )
    # "user", "assistant", "system"
    role: Mapped[str] = mapped_column(String(20), nullable=False)
    content: Mapped[str] = mapped_column(Text, nullable=False)
    # "order", "product", "general"
    context_type: Mapped[str | None] = mapped_column(String(50), nullable=True)
    context_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), nullable=True)
    token_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Naive UTC timestamp assigned by the ORM at insert time.
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=datetime.utcnow,
        nullable=False,
    )
|
||||
@@ -0,0 +1,769 @@
|
||||
"""Chat service — Azure OpenAI with function calling for tenant AI agent.
|
||||
|
||||
Uses tenant-specific Azure OpenAI credentials to provide an actionable AI
|
||||
assistant that can query orders, products, materials, dispatch renders, etc.
|
||||
All operations are scoped to the user's tenant_id.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── System prompt ────────────────────────────────────────────────────────────
|
||||
|
||||
# System message for the assistant: who it is, what it can do, and the rule to
# confirm before any order creation or render dispatch.
SYSTEM_PROMPT = "\n".join([
    "You are the Schaeffler Automat AI assistant. You help users manage their automated render pipeline for Schaeffler product images.",
    "",
    "You can:",
    "- List and search orders and products",
    "- Create new render orders",
    "- Dispatch and check render status",
    "- Set material overrides",
    "- Set render overrides (format, resolution)",
    "- Check material mapping status",
    "- Query the database for statistics",
    "",
    "Always be concise and helpful. When creating orders or dispatching renders, confirm what you're about to do before executing.",
])
|
||||
|
||||
# ── Tool definitions (OpenAI function-calling schema) ────────────────────────
|
||||
|
||||
def _function_spec(name, description, properties=None, required=None):
    """Build one function-calling tool entry in the schema used by TOOLS.

    ``required`` is only emitted when given, so tools without mandatory
    arguments carry no "required" key at all.
    """
    parameters = {"type": "object", "properties": properties or {}}
    if required:
        parameters["required"] = required
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters,
        },
    }


# Tool catalogue offered to the model. The names match the dispatch performed
# in _execute_tool; the order of the entries is preserved.
TOOLS = [
    _function_spec(
        "list_orders",
        "List recent orders with render progress. Returns order number, status, line count, and render progress.",
        {
            "status": {
                "type": "string",
                "description": "Filter by status: draft, submitted, processing, completed, rejected. Empty for all.",
                "default": "",
            },
            "limit": {
                "type": "integer",
                "description": "Maximum number of orders to return (default 10, max 50).",
                "default": 10,
            },
        },
    ),
    _function_spec(
        "search_products",
        "Search products by name, PIM-ID, baureihe, or category.",
        {
            "query": {
                "type": "string",
                "description": "Search text (matches name, PIM-ID, baureihe).",
                "default": "",
            },
            "category": {
                "type": "string",
                "description": "Filter by category key.",
                "default": "",
            },
            "limit": {
                "type": "integer",
                "description": "Max results (default 20, max 50).",
                "default": 20,
            },
        },
    ),
    _function_spec(
        "create_order",
        "Create a new render order with the given products and output type. Confirm with the user before executing.",
        {
            "product_ids": {
                "type": "array",
                "items": {"type": "string"},
                "description": "List of product UUIDs to include.",
            },
            "output_type_name": {
                "type": "string",
                "description": "Name of the output type (e.g. 'Still Render', 'Turntable').",
            },
            "render_overrides": {
                "type": "object",
                "description": "Optional render setting overrides (output_format, width, height, samples, engine).",
            },
            "material_override": {
                "type": "string",
                "description": "Optional SCHAEFFLER library material name to apply to all lines.",
                "default": "",
            },
        },
        ["product_ids", "output_type_name"],
    ),
    _function_spec(
        "dispatch_renders",
        "Dispatch (or retry) renders for all pending/failed lines in an order.",
        {
            "order_id": {
                "type": "string",
                "description": "UUID of the order.",
            },
        },
        ["order_id"],
    ),
    _function_spec(
        "get_order_status",
        "Get detailed status of a specific order including all lines and render progress.",
        {
            "order_id": {
                "type": "string",
                "description": "UUID of the order.",
            },
        },
        ["order_id"],
    ),
    _function_spec(
        "set_material_override",
        "Set a material override on all lines of an order. All parts will be rendered with this single material.",
        {
            "order_id": {
                "type": "string",
                "description": "UUID of the order.",
            },
            "material_name": {
                "type": "string",
                "description": "SCHAEFFLER library material name, or empty string to clear.",
            },
        },
        ["order_id", "material_name"],
    ),
    _function_spec(
        "set_render_overrides",
        "Set render overrides on all lines of an order (output_format, width, height, samples, engine, etc.).",
        {
            "order_id": {
                "type": "string",
                "description": "UUID of the order.",
            },
            "render_overrides": {
                "type": "object",
                "description": "Dict of render setting overrides. Pass null/empty to clear.",
            },
        },
        ["order_id", "render_overrides"],
    ),
    _function_spec(
        "get_render_stats",
        "Get render pipeline statistics: queue status, throughput, product/order counts.",
    ),
    _function_spec(
        "check_materials",
        "Check if all materials in an order are mapped to library materials. Returns unmapped materials.",
        {
            "order_id": {
                "type": "string",
                "description": "UUID of the order to check.",
            },
        },
        ["order_id"],
    ),
    _function_spec(
        "query_database",
        "Execute a read-only SQL SELECT query against the database. Results are automatically filtered to the current tenant. Tables: orders, order_lines, products, cad_files, materials, material_aliases, output_types, media_assets, render_templates.",
        {
            "sql": {
                "type": "string",
                "description": "A SELECT SQL query to execute.",
            },
        },
        ["sql"],
    ),
]
|
||||
|
||||
|
||||
# ── Azure OpenAI client helper ───────────────────────────────────────────────
|
||||
|
||||
def _resolve_credentials(tenant_config: dict | None) -> dict:
    """Resolve Azure OpenAI credentials from tenant config or global settings.

    When the tenant config has a truthy ``ai_enabled``, each connection value
    comes from the tenant and falls back to the global setting if it is
    missing or empty. Otherwise only the global settings are used.

    Returns:
        Dict with ``api_key``, ``endpoint``, ``deployment``, ``api_version``,
        ``max_tokens`` (int) and ``temperature`` (float).
    """
    default_max_tokens = 1000
    default_temperature = 0.3

    if not (tenant_config and tenant_config.get("ai_enabled")):
        return {
            "api_key": settings.azure_openai_api_key,
            "endpoint": settings.azure_openai_endpoint,
            "deployment": settings.azure_openai_deployment,
            "api_version": settings.azure_openai_api_version,
            "max_tokens": default_max_tokens,
            "temperature": default_temperature,
        }

    def _tuning(key, cast, fallback):
        raw = tenant_config.get(key)
        # A key stored as null or "" means "not set" and must not reach
        # int()/float(); 0 is a legitimate value (e.g. temperature) and is kept.
        if raw is None or raw == "":
            return fallback
        return cast(raw)

    return {
        "api_key": tenant_config.get("ai_api_key") or settings.azure_openai_api_key,
        "endpoint": tenant_config.get("ai_endpoint") or settings.azure_openai_endpoint,
        "deployment": tenant_config.get("ai_deployment") or settings.azure_openai_deployment,
        "api_version": tenant_config.get("ai_api_version") or settings.azure_openai_api_version,
        "max_tokens": _tuning("ai_max_tokens", int, default_max_tokens),
        "temperature": _tuning("ai_temperature", float, default_temperature),
    }
|
||||
|
||||
|
||||
# ── Tool execution (async, tenant-scoped) ────────────────────────────────────
|
||||
|
||||
async def _execute_tool(
    name: str,
    arguments: dict,
    tenant_id: str,
    user_id: str,
    db: AsyncSession,
) -> str:
    """Run the named tool with the given arguments and return its JSON string.

    Never raises: an unknown tool name or any exception from the tool is
    reported as ``{"error": ...}`` (exceptions are also logged).
    """
    # Each entry defers both the name lookup and the **arguments expansion
    # until it is invoked inside the try block below.
    handlers = {
        "list_orders": lambda: _tool_list_orders(db, tenant_id, **arguments),
        "search_products": lambda: _tool_search_products(db, tenant_id, **arguments),
        "create_order": lambda: _tool_create_order(db, tenant_id, user_id, **arguments),
        "dispatch_renders": lambda: _tool_dispatch_renders(db, tenant_id, **arguments),
        "get_order_status": lambda: _tool_get_order_status(db, tenant_id, **arguments),
        "set_material_override": lambda: _tool_set_material_override(db, tenant_id, **arguments),
        "set_render_overrides": lambda: _tool_set_render_overrides(db, tenant_id, **arguments),
        "get_render_stats": lambda: _tool_get_render_stats(db, tenant_id),
        "check_materials": lambda: _tool_check_materials(db, tenant_id, **arguments),
        "query_database": lambda: _tool_query_database(db, tenant_id, **arguments),
    }
    try:
        handler = handlers.get(name)
        if handler is None:
            return json.dumps({"error": f"Unknown tool: {name}"})
        return await handler()
    except Exception as exc:
        logger.exception("Tool execution failed: %s(%s)", name, arguments)
        return json.dumps({"error": str(exc)})
|
||||
|
||||
|
||||
async def _tool_list_orders(db: AsyncSession, tenant_id: str, status: str = "", limit: int = 10) -> str:
    """Return the tenant's newest orders with per-order line counts as a JSON array.

    ``limit`` is clamped to the range 1..50. A non-empty ``status`` keeps only
    orders whose status equals it.
    """
    bind = {"tenant_id": tenant_id, "limit": min(max(limit, 1), 50)}
    fragments = [
        """
        SELECT o.id, o.order_number, o.status, o.created_at,
               COUNT(ol.id) AS line_count,
               COUNT(ol.id) FILTER (WHERE ol.render_status = 'completed') AS completed_lines,
               COUNT(ol.id) FILTER (WHERE ol.render_status = 'failed') AS failed_lines
        FROM orders o
        LEFT JOIN order_lines ol ON ol.order_id = o.id
        WHERE o.tenant_id = :tenant_id
    """
    ]
    if status:
        fragments.append(" AND o.status = :status")
        bind["status"] = status
    fragments.append(" GROUP BY o.id ORDER BY o.created_at DESC LIMIT :limit")

    outcome = await db.execute(text("".join(fragments)), bind)
    records = [dict(record) for record in outcome.mappings().all()]
    # default=str renders non-JSON-native values (ids, timestamps) as text.
    return json.dumps(records, indent=2, default=str)
|
||||
|
||||
|
||||
async def _tool_search_products(db: AsyncSession, tenant_id: str, query: str = "", category: str = "", limit: int = 20) -> str:
    """Search the tenant's products and return the matches as a JSON array.

    ``query`` is matched case-insensitively as a substring of name, PIM-ID or
    baureihe; ``category`` must equal the category key. ``limit`` is clamped
    to the range 1..50 and results are ordered by name.
    """
    bind = {"tenant_id": tenant_id, "limit": min(max(limit, 1), 50)}
    fragments = [
        """
        SELECT p.id, p.name, p.pim_id, p.category_key, p.baureihe,
               p.cad_file_id IS NOT NULL AS has_step,
               cf.processing_status
        FROM products p
        LEFT JOIN cad_files cf ON cf.id = p.cad_file_id
        WHERE p.tenant_id = :tenant_id
    """
    ]
    if query:
        fragments.append(" AND (p.name ILIKE :q OR p.pim_id ILIKE :q OR p.baureihe ILIKE :q)")
        bind["q"] = f"%{query}%"
    if category:
        fragments.append(" AND p.category_key = :category")
        bind["category"] = category
    fragments.append(" ORDER BY p.name LIMIT :limit")

    outcome = await db.execute(text("".join(fragments)), bind)
    records = [dict(record) for record in outcome.mappings().all()]
    return json.dumps(records, indent=2, default=str)
|
||||
|
||||
|
||||
async def _tool_create_order(
    db: AsyncSession,
    tenant_id: str,
    user_id: str,
    product_ids: list[str] | None = None,
    output_type_name: str = "",
    render_overrides: dict | None = None,
    material_override: str = "",
) -> str:
    """Create an order via internal httpx call to the backend API.

    Every product id is validated as a UUID and checked against the caller's
    tenant before the API is called, because the request is authenticated
    with an elevated token.

    Returns:
        JSON string with ``order_id``, ``order_number``, ``status`` and
        ``line_count`` on success, or ``{"error": ...}`` on any failure.
    """
    if not product_ids:
        return json.dumps({"error": "product_ids is required"})

    # The ids come from the model; reject anything that is not a UUID before
    # it reaches the database or the API.
    normalized_ids: list[str] = []
    for pid in product_ids:
        try:
            normalized_ids.append(str(uuid.UUID(str(pid))))
        except ValueError:
            return json.dumps({"error": f"Invalid product id: {pid!r}"})

    # Same tenant guard the other write tools apply to orders: the API call
    # below runs as "global_admin", so ownership is checked here.
    for pid in normalized_ids:
        owned = await db.execute(
            text("SELECT 1 FROM products WHERE id = :pid AND tenant_id = :tid"),
            {"pid": pid, "tid": tenant_id},
        )
        if not owned.first():
            return json.dumps({"error": f"Product {pid} not found or not in your tenant"})

    # Resolve output type ID from name
    ot_id = None
    if output_type_name:
        ot_result = await db.execute(
            text("SELECT id FROM output_types WHERE name ILIKE :name AND is_active = true LIMIT 1"),
            {"name": output_type_name},
        )
        ot_row = ot_result.mappings().first()
        if not ot_row:
            return json.dumps({"error": f"No active output type found matching '{output_type_name}'"})
        ot_id = str(ot_row["id"])

    lines = []
    for pid in normalized_ids:
        line: dict = {"product_id": pid}
        if ot_id:
            line["output_type_id"] = ot_id
        if render_overrides:
            line["render_overrides"] = render_overrides
        if material_override:
            line["material_override"] = material_override
        lines.append(line)

    import httpx

    from app.utils.auth import create_access_token

    # Call backend API internally using a service token.
    # NOTE(review): the token is minted with the "global_admin" role for the
    # calling user; confirm this elevation is intended.
    token = create_access_token(user_id, "global_admin", tenant_id)

    try:
        async with httpx.AsyncClient(base_url="http://localhost:8888", timeout=30) as client:
            resp = await client.post(
                "/api/orders",
                json={"lines": lines},
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            data = resp.json()
            return json.dumps({
                "order_id": data["id"],
                "order_number": data["order_number"],
                "status": data["status"],
                "line_count": data.get("line_count", len(lines)),
            }, indent=2)
    except Exception as exc:
        return json.dumps({"error": f"Failed to create order: {exc}"})
|
||||
|
||||
|
||||
async def _tool_dispatch_renders(db: AsyncSession, tenant_id: str, order_id: str = "") -> str:
    """Dispatch renders via internal httpx call.

    Returns the API response as a JSON string, or ``{"error": ...}`` when the
    order id is malformed, the order is not in the tenant, or the call fails.
    """
    # order_id comes from the model and defaults to ""; validate it before it
    # is compared against a UUID column and interpolated into the URL path.
    try:
        order_id = str(uuid.UUID(order_id))
    except (ValueError, TypeError, AttributeError):
        return json.dumps({"error": "order_id must be a valid UUID"})

    # Verify order belongs to tenant
    check = await db.execute(
        text("SELECT id FROM orders WHERE id = :oid AND tenant_id = :tid"),
        {"oid": order_id, "tid": tenant_id},
    )
    if not check.first():
        return json.dumps({"error": "Order not found or not in your tenant"})

    import httpx

    from app.utils.auth import create_access_token

    # NOTE(review): the internal call authenticates as an all-zero user id with
    # the "global_admin" role; confirm this elevation is intended.
    token = create_access_token(str(uuid.UUID(int=0)), "global_admin", tenant_id)
    try:
        async with httpx.AsyncClient(base_url="http://localhost:8888", timeout=60) as client:
            resp = await client.post(
                f"/api/orders/{order_id}/dispatch-renders",
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            return json.dumps(resp.json(), indent=2, default=str)
    except Exception as exc:
        return json.dumps({"error": f"Failed to dispatch renders: {exc}"})
|
||||
|
||||
|
||||
async def _tool_get_order_status(db: AsyncSession, tenant_id: str, order_id: str = "") -> str:
    """Return one order of the tenant with all of its lines as a JSON string.

    ``lines`` is an empty list for an order without lines. Returns
    ``{"error": ...}`` when the id is malformed or the order is not in the
    tenant.
    """
    # order_id comes from the model and defaults to ""; validate it before it
    # is compared against a UUID column.
    try:
        order_id = str(uuid.UUID(order_id))
    except (ValueError, TypeError, AttributeError):
        return json.dumps({"error": "order_id must be a valid UUID"})

    # The LEFT JOIN yields one all-NULL row for an order without lines; the
    # FILTER drops it so such an order reports [] instead of a phantom line.
    sql = """
        SELECT o.id, o.order_number, o.status, o.created_at,
               COALESCE(
                   json_agg(json_build_object(
                       'line_id', ol.id,
                       'product_name', p.name,
                       'render_status', ol.render_status,
                       'output_type', ot.name,
                       'material_override', ol.material_override
                   )) FILTER (WHERE ol.id IS NOT NULL),
                   '[]'::json
               ) AS lines
        FROM orders o
        LEFT JOIN order_lines ol ON ol.order_id = o.id
        LEFT JOIN products p ON p.id = ol.product_id
        LEFT JOIN output_types ot ON ot.id = ol.output_type_id
        WHERE o.id = :oid AND o.tenant_id = :tid
        GROUP BY o.id
    """
    result = await db.execute(text(sql), {"oid": order_id, "tid": tenant_id})
    row = result.mappings().first()
    if not row:
        return json.dumps({"error": "Order not found or not in your tenant"})
    return json.dumps(dict(row), indent=2, default=str)
|
||||
|
||||
|
||||
async def _tool_set_material_override(db: AsyncSession, tenant_id: str, order_id: str = "", material_name: str = "") -> str:
    """Set material override via internal httpx call.

    An empty ``material_name`` is sent as ``null``. Returns the API response
    as a JSON string, or ``{"error": ...}`` when the order id is malformed,
    the order is not in the tenant, or the call fails.
    """
    # order_id comes from the model and defaults to ""; validate it before it
    # is compared against a UUID column and interpolated into the URL path.
    try:
        order_id = str(uuid.UUID(order_id))
    except (ValueError, TypeError, AttributeError):
        return json.dumps({"error": "order_id must be a valid UUID"})

    # Verify order belongs to tenant
    check = await db.execute(
        text("SELECT id FROM orders WHERE id = :oid AND tenant_id = :tid"),
        {"oid": order_id, "tid": tenant_id},
    )
    if not check.first():
        return json.dumps({"error": "Order not found or not in your tenant"})

    import httpx

    from app.utils.auth import create_access_token

    # NOTE(review): the internal call authenticates as an all-zero user id with
    # the "global_admin" role; confirm this elevation is intended.
    token = create_access_token(str(uuid.UUID(int=0)), "global_admin", tenant_id)
    try:
        async with httpx.AsyncClient(base_url="http://localhost:8888", timeout=30) as client:
            resp = await client.post(
                f"/api/orders/{order_id}/batch-material-override",
                json={"material_override": material_name or None},
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            return json.dumps(resp.json(), indent=2, default=str)
    except Exception as exc:
        return json.dumps({"error": f"Failed to set material override: {exc}"})
|
||||
|
||||
|
||||
async def _tool_set_render_overrides(db: AsyncSession, tenant_id: str, order_id: str = "", render_overrides: dict | None = None) -> str:
    """Set render overrides via internal httpx call.

    ``render_overrides`` is forwarded as given (``None`` included). Returns
    the API response as a JSON string, or ``{"error": ...}`` when the order id
    is malformed, the order is not in the tenant, or the call fails.
    """
    # order_id comes from the model and defaults to ""; validate it before it
    # is compared against a UUID column and interpolated into the URL path.
    try:
        order_id = str(uuid.UUID(order_id))
    except (ValueError, TypeError, AttributeError):
        return json.dumps({"error": "order_id must be a valid UUID"})

    # Verify order belongs to tenant
    check = await db.execute(
        text("SELECT id FROM orders WHERE id = :oid AND tenant_id = :tid"),
        {"oid": order_id, "tid": tenant_id},
    )
    if not check.first():
        return json.dumps({"error": "Order not found or not in your tenant"})

    import httpx

    from app.utils.auth import create_access_token

    # NOTE(review): the internal call authenticates as an all-zero user id with
    # the "global_admin" role; confirm this elevation is intended.
    token = create_access_token(str(uuid.UUID(int=0)), "global_admin", tenant_id)
    try:
        async with httpx.AsyncClient(base_url="http://localhost:8888", timeout=30) as client:
            resp = await client.post(
                f"/api/orders/{order_id}/batch-render-overrides",
                json={"render_overrides": render_overrides},
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            return json.dumps(resp.json(), indent=2, default=str)
    except Exception as exc:
        return json.dumps({"error": f"Failed to set render overrides: {exc}"})
|
||||
|
||||
|
||||
async def _tool_get_render_stats(db: AsyncSession, tenant_id: str) -> str:
    """Return per-tenant order, product and render-status counts as JSON.

    One statement made of scalar sub-selects yields a single row with the keys
    ``total_orders``, ``total_products``, ``total_lines``, ``completed_renders``,
    ``failed_renders``, ``pending_renders`` and ``active_renders`` (the last one
    counts lines whose ``render_status`` is ``'processing'``).

    Args:
        db: Async database session the statement runs on.
        tenant_id: Tenant every sub-select is filtered by.

    Returns:
        A JSON object string with the counters, or ``{}`` if no row came back.
    """
    stats_query = text("""
        SELECT
            (SELECT count(*) FROM orders WHERE tenant_id = :tid) AS total_orders,
            (SELECT count(*) FROM products WHERE tenant_id = :tid) AS total_products,
            (SELECT count(*) FROM order_lines ol
                JOIN orders o ON o.id = ol.order_id
                WHERE o.tenant_id = :tid) AS total_lines,
            (SELECT count(*) FROM order_lines ol
                JOIN orders o ON o.id = ol.order_id
                WHERE o.tenant_id = :tid AND ol.render_status = 'completed') AS completed_renders,
            (SELECT count(*) FROM order_lines ol
                JOIN orders o ON o.id = ol.order_id
                WHERE o.tenant_id = :tid AND ol.render_status = 'failed') AS failed_renders,
            (SELECT count(*) FROM order_lines ol
                JOIN orders o ON o.id = ol.order_id
                WHERE o.tenant_id = :tid AND ol.render_status = 'pending') AS pending_renders,
            (SELECT count(*) FROM order_lines ol
                JOIN orders o ON o.id = ol.order_id
                WHERE o.tenant_id = :tid AND ol.render_status = 'processing') AS active_renders
    """)
    outcome = await db.execute(stats_query, {"tid": tenant_id})
    stats_row = outcome.mappings().first()
    if stats_row:
        counters = dict(stats_row)
    else:
        counters = {}
    return json.dumps(counters, indent=2, default=str)
|
||||
|
||||
|
||||
async def _tool_check_materials(
    db: AsyncSession,
    tenant_id: str,
    order_id: str = "",
) -> str:
    """Fetch the material check for an order from the backend HTTP API.

    The order is first looked up with a tenant filter; when no row matches, an
    error payload is returned and no HTTP request is made. Otherwise the
    ``check-materials`` endpoint of that order is requested with GET.

    Args:
        db: Async database session used for the ownership lookup.
        tenant_id: Tenant the order has to belong to.
        order_id: Identifier of the order to inspect.

    Returns:
        A JSON string: the endpoint's response body on success, otherwise an
        object with a single ``error`` key.
    """
    import httpx
    from app.utils.auth import create_access_token

    ownership = await db.execute(
        text("SELECT id FROM orders WHERE id = :oid AND tenant_id = :tid"),
        {"oid": order_id, "tid": tenant_id},
    )
    order_row = ownership.first()
    if not order_row:
        return json.dumps({"error": "Order not found or not in your tenant"})

    # Token for the internal call: nil UUID, "global_admin" and the tenant id are
    # handed to create_access_token (presumably subject / role / tenant — verify).
    token = create_access_token(str(uuid.UUID(int=0)), "global_admin", tenant_id)
    endpoint = f"/api/orders/{order_id}/check-materials"
    auth_headers = {"Authorization": f"Bearer {token}"}

    try:
        async with httpx.AsyncClient(base_url="http://localhost:8888", timeout=30) as api:
            reply = await api.get(endpoint, headers=auth_headers)
            reply.raise_for_status()
            return json.dumps(reply.json(), indent=2, default=str)
    except Exception as err:
        # Transport errors and non-2xx statuses are reported back to the model as text.
        return json.dumps({"error": f"Failed to check materials: {err}"})
|
||||
|
||||
|
||||
async def _tool_query_database(db: AsyncSession, tenant_id: str, sql: str = "") -> str:
    """Execute a single read-only SQL statement supplied by the chat agent.

    The statement is screened before it reaches the database:

    * it must start with ``SELECT`` or ``WITH``;
    * apart from trailing terminators it must not contain ``;`` (no stacked
      statements);
    * it must not contain any write/DDL keyword anywhere in its text. The scan
      covers the complete statement, including text behind ``--`` or ``/*``,
      and uses word boundaries so that ``(DELETE ...`` inside a CTE or a
      keyword preceded by a newline is caught, while identifiers such as
      ``created_at`` are not.

    Args:
        db: Async database session the statement runs on.
        tenant_id: Value bound to the ``:tenant_id`` parameter of the statement.
        sql: SQL text produced by the model.

    Returns:
        A JSON array string with at most 100 rows, the plain text
        ``"Query returned 0 rows."`` for an empty result, or a JSON object
        with an ``error`` key when the statement is rejected or fails.
    """
    import re

    # A trailing ";" is harmless and common in generated SQL; drop it so that
    # only *interior* semicolons are treated as statement stacking.
    statement = sql.strip().rstrip("; \t\r\n")
    sql_upper = statement.upper()

    if not sql_upper.startswith(("SELECT", "WITH")):
        return json.dumps({"error": "Only SELECT/WITH queries are allowed (read-only)."})

    if ";" in statement:
        return json.dumps({"error": "Multiple statements are not allowed (read-only)."})

    forbidden = re.search(
        r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|TRUNCATE|CREATE)\b",
        sql_upper,
    )
    if forbidden:
        return json.dumps({"error": f"{forbidden.group(1)} statements are not allowed (read-only)."})

    # NOTE(review): :tenant_id is always bound, but nothing here forces the
    # statement to reference it, and a keyword denylist is not a sandbox.
    # Tenant isolation for this tool should additionally be enforced in the
    # database (read-only role / row-level security).
    try:
        result = await db.execute(text(statement), {"tenant_id": tenant_id})
        rows = result.mappings().all()
        if not rows:
            return "Query returned 0 rows."
        return json.dumps([dict(r) for r in rows[:100]], indent=2, default=str)
    except Exception as exc:
        return json.dumps({"error": f"Query error: {exc}"})
|
||||
|
||||
|
||||
# ── Message persistence ──────────────────────────────────────────────────────
|
||||
|
||||
async def _save_message(
    db: AsyncSession,
    tenant_id: str,
    user_id: str,
    session_id: str,
    role: str,
    content: str,
    context_type: str | None = None,
    context_id: str | None = None,
    token_count: int | None = None,
) -> dict:
    """Insert one row into ``chat_messages`` and return a serialisable summary.

    A fresh UUID and a naive UTC timestamp are generated for the row. The
    INSERT is executed on ``db`` but not committed here.

    Args:
        db: Async database session the INSERT runs on.
        tenant_id: Tenant the message belongs to.
        user_id: User stored with the message.
        session_id: Conversation the message is part of.
        role: Chat role stored with the message (callers in this module pass
            ``"user"`` or ``"assistant"``).
        content: Message text.
        context_type: Optional context kind stored alongside the message.
        context_id: Optional context identifier stored alongside the message.
        token_count: Optional token usage stored with the message.

    Returns:
        A dict with ``id``, ``role``, ``content``, ``context_type``,
        ``context_id`` (stringified when set), ``token_count`` and
        ``created_at`` (ISO-8601 text).
    """
    message_id = str(uuid.uuid4())
    created = datetime.utcnow()

    insert_stmt = text("""
        INSERT INTO chat_messages (id, tenant_id, user_id, session_id, role, content,
            context_type, context_id, token_count, created_at)
        VALUES (:id, :tenant_id, :user_id, :session_id, :role, :content,
            :context_type, :context_id, :token_count, :created_at)
    """)
    bind_values = {
        "id": message_id,
        "tenant_id": tenant_id,
        "user_id": user_id,
        "session_id": session_id,
        "role": role,
        "content": content,
        "context_type": context_type,
        "context_id": context_id,
        "token_count": token_count,
        "created_at": created,
    }
    await db.execute(insert_stmt, bind_values)

    # The returned summary omits tenant/user/session and uses JSON-friendly values.
    summary = {
        "id": message_id,
        "role": role,
        "content": content,
        "context_type": context_type,
        "context_id": None,
        "token_count": token_count,
        "created_at": created.isoformat(),
    }
    if context_id:
        summary["context_id"] = str(context_id)
    return summary
|
||||
|
||||
|
||||
async def _load_session_messages(db: AsyncSession, session_id: str, tenant_id: str) -> list[dict]:
    """Load the most recent conversation history of a session for the context window.

    Only the newest 50 messages are fetched, so a long-running session keeps
    its latest turns in context instead of being frozen on its first 50 rows.
    The rows are selected newest-first to make ``LIMIT`` keep the recent ones
    and are then reversed, because the chat API expects chronological order.

    Args:
        db: Async database session the query runs on.
        session_id: Conversation whose messages are loaded.
        tenant_id: Tenant filter applied in addition to the session filter.

    Returns:
        Up to 50 ``{"role": ..., "content": ...}`` dicts, oldest first.
    """
    result = await db.execute(
        text("""
            SELECT role, content FROM chat_messages
            WHERE session_id = :sid AND tenant_id = :tid
            ORDER BY created_at DESC
            LIMIT 50
        """),
        {"sid": session_id, "tid": tenant_id},
    )
    newest_first = result.mappings().all()
    return [{"role": r["role"], "content": r["content"]} for r in reversed(newest_first)]
|
||||
|
||||
|
||||
# ── Main chat function ───────────────────────────────────────────────────────
|
||||
|
||||
async def chat_with_agent(
    message: str,
    session_id: str,
    tenant_id: str,
    user_id: str,
    db: AsyncSession,
    tenant_config: dict | None = None,
    context_type: str | None = None,
    context_id: str | None = None,
) -> dict:
    """Process a user message through the Azure OpenAI agent with function calling.

    Flow: resolve credentials, load the stored session history, persist the
    user message, then let the model answer. While the model keeps requesting
    tool calls (at most 10 rounds) each call is executed via ``_execute_tool``
    and its result is fed back. The final text is persisted as the assistant
    message and the session is committed.

    Args:
        message: Text the user typed.
        session_id: Conversation the message belongs to.
        tenant_id: Tenant on whose behalf tools are executed.
        user_id: User stored with both persisted messages.
        db: Async database session used for history, persistence and tools.
        tenant_config: Tenant settings passed to ``_resolve_credentials``.
        context_type: Optional context kind appended to the system prompt.
        context_id: Optional context identifier appended to the system prompt.

    Returns:
        {"session_id": str, "user_message": dict, "assistant_message": dict}.

    Raises:
        ValueError: If no API key or endpoint is configured.
    """
    # Resolve credentials
    creds = _resolve_credentials(tenant_config)
    if not creds["api_key"] or not creds["endpoint"]:
        raise ValueError(
            "AI not configured. Ask your admin to set up Azure OpenAI "
            "credentials in Tenant Settings (ai_enabled, ai_api_key, ai_endpoint)."
        )

    # The async client is used so that waiting on the model does not block the
    # event loop this coroutine runs on.
    from openai import AsyncAzureOpenAI

    client = AsyncAzureOpenAI(
        api_key=creds["api_key"],
        azure_endpoint=creds["endpoint"],
        api_version=creds["api_version"],
    )

    # Build message history
    history = await _load_session_messages(db, session_id, tenant_id)

    # Build context-aware system prompt
    system_content = SYSTEM_PROMPT
    if context_type and context_id:
        system_content += f"\n\nCurrent context: {context_type} {context_id}"
    system_content += f"\n\nThe user's tenant_id is '{tenant_id}'. Always filter queries by this tenant_id."

    messages: list[dict] = [{"role": "system", "content": system_content}]
    messages.extend(history)
    messages.append({"role": "user", "content": message})

    # Save user message
    user_msg = await _save_message(
        db, tenant_id, user_id, session_id, "user", message,
        context_type=context_type, context_id=context_id,
    )

    async def _request_completion():
        """Ask the model for the next turn based on the current ``messages``."""
        return await client.chat.completions.create(
            model=creds["deployment"],
            messages=messages,
            tools=TOOLS,
            tool_choice="auto",
            max_tokens=creds["max_tokens"],
            temperature=creds["temperature"],
        )

    # OpenAI function-calling loop
    max_iterations = 10
    iteration = 0
    total_tokens = 0

    try:
        response = await _request_completion()
        if response.usage:
            total_tokens += response.usage.total_tokens

        while response.choices[0].message.tool_calls and iteration < max_iterations:
            iteration += 1
            assistant_msg = response.choices[0].message

            # Append assistant message with tool calls to context
            messages.append({
                "role": "assistant",
                "content": assistant_msg.content or "",
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {"name": tc.function.name, "arguments": tc.function.arguments},
                    }
                    for tc in assistant_msg.tool_calls
                ],
            })

            # Execute each tool call. Every call id must receive a "tool"
            # reply, so unusable arguments are answered with an error payload
            # instead of aborting the request.
            for tool_call in assistant_msg.tool_calls:
                fn_name = tool_call.function.name
                try:
                    fn_args = json.loads(tool_call.function.arguments or "{}")
                except json.JSONDecodeError as exc:
                    fn_args = None
                    tool_result = json.dumps({"error": f"Invalid tool arguments: {exc}"})
                else:
                    if not isinstance(fn_args, dict):
                        fn_args = None
                        tool_result = json.dumps({"error": "Invalid tool arguments: expected a JSON object"})

                if fn_args is None:
                    logger.warning("Chat tool call %s rejected: bad arguments [tenant=%s]", fn_name, tenant_id)
                else:
                    logger.info("Chat tool call: %s(%s) [tenant=%s]", fn_name, fn_args, tenant_id)
                    tool_result = await _execute_tool(fn_name, fn_args, tenant_id, user_id, db)

                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": tool_result,
                })

            # Next LLM call with tool results
            response = await _request_completion()
            if response.usage:
                total_tokens += response.usage.total_tokens
    finally:
        await client.close()

    # Extract final text response
    final_message = response.choices[0].message
    final_content = final_message.content or ""
    if not final_content and final_message.tool_calls:
        # The tool-call budget ran out while the model still wanted more
        # calls; store an explanation rather than an empty reply.
        final_content = (
            "I could not finish this request within the allowed number of "
            "tool calls. Please try again with a more specific request."
        )

    # Save assistant response
    assistant_msg_out = await _save_message(
        db, tenant_id, user_id, session_id, "assistant", final_content,
        context_type=context_type, context_id=context_id,
        token_count=total_tokens,
    )

    await db.commit()

    return {
        "session_id": session_id,
        "user_message": user_msg,
        "assistant_message": assistant_msg_out,
    }
|
||||
Reference in New Issue
Block a user