228 lines
7.8 KiB
Python
228 lines
7.8 KiB
Python
"""Chat API endpoints for tenant AI agent conversations."""
|
|
import logging
|
|
import uuid
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.database import get_db
|
|
from app.models.user import User
|
|
from app.utils.auth import get_current_user
|
|
|
|
# Logger named after this module.
logger = logging.getLogger(name=__name__)

# Every endpoint below is mounted under /chat and tagged "chat" in the OpenAPI schema.
router = APIRouter(tags=["chat"], prefix="/chat")
|
|
|
|
|
|
# ── Pydantic schemas ─────────────────────────────────────────────────────────
|
|
|
|
class ChatMessageCreate(BaseModel):
    # Request body of POST /chat/messages.

    # The user's message text; send_message forwards it to chat_with_agent unchanged.
    message: str

    # Session to continue. When omitted, send_message generates a new uuid4 string.
    session_id: str | None = None

    # Optional context reference forwarded unchanged to chat_with_agent.
    # NOTE(review): meaning/allowed values are defined by the chat service,
    # not visible here — confirm there.
    context_type: str | None = None

    context_id: str | None = None
|
|
|
|
|
|
class ChatMessageOut(BaseModel):
    # One chat message as returned by GET /chat/sessions/{session_id}/messages.
    # Also used for both messages of ChatResponse (presumably filled by the
    # chat service — not visible here).

    # Message id, read from the database as text (id::text).
    id: str

    # Author role of the message. NOTE(review): the set of values is decided by
    # the chat service — confirm there.
    role: str

    content: str

    context_type: str | None = None

    # Read from the database as text (context_id::text).
    context_id: str | None = None

    token_count: int | None = None

    # Timestamp rendered with isoformat(), or "" when the column is NULL.
    created_at: str
|
|
|
|
|
|
class ChatResponse(BaseModel):
    # Response model of POST /chat/messages: one conversation turn.

    # Session the turn belongs to (generated by send_message if the request
    # did not supply one).
    session_id: str

    # The user's message and the assistant's reply. NOTE(review): these are
    # built by chat_with_agent, whose return value is not visible here.
    user_message: ChatMessageOut

    assistant_message: ChatMessageOut
|
|
|
|
|
|
class ChatSessionSummary(BaseModel):
    # One entry of GET /chat/sessions.

    session_id: str

    # Content of the newest message in the session, truncated to 200
    # characters ("" when NULL).
    last_message: str

    # Number of chat_messages rows in the session that match the caller's
    # user and tenant.
    message_count: int

    # isoformat() of the earliest matching message's timestamp, or "".
    created_at: str
|
|
|
|
|
|
# ── Endpoints ─────────────────────────────────────────────────────────────────
|
|
|
|
@router.post("/messages", response_model=ChatResponse)
async def send_message(
    body: ChatMessageCreate,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Send a message to the AI assistant and get a response.

    Creates a new session if ``session_id`` is not provided. A caller-supplied
    ``session_id`` is accepted when it already belongs to the current user or
    has never been used. An ID that already holds another user's messages is
    rejected with 404, matching the other session endpoints. This way foreign
    sessions can neither be appended to nor distinguished from missing ones.

    Uses the tenant's Azure OpenAI credentials for the LLM call.

    Raises:
        HTTPException: 404 for a session owned by someone else.
            422 when the chat service raises ``ValueError`` (AI not
            configured) or the provider's content filter blocks the prompt.
            500 for any other chat-service failure.
            An ``HTTPException`` raised by the chat service is passed through
            unchanged.
    """
    from app.services.chat_service import chat_with_agent

    # Load tenant config
    tenant_config = await _get_tenant_config(db, user)

    session_id = body.session_id or str(uuid.uuid4())

    # A supplied session_id must either be ours or be brand new.
    if body.session_id:
        owned = await db.execute(
            text("""
                SELECT 1 FROM chat_messages
                WHERE session_id = :sid AND user_id = :uid
                LIMIT 1
            """),
            {"sid": session_id, "uid": str(user.id)},
        )
        if not owned.first():
            # Not ours: only acceptable if nobody has used this ID yet.
            existing = await db.execute(
                text("""
                    SELECT 1 FROM chat_messages
                    WHERE session_id = :sid
                    LIMIT 1
                """),
                {"sid": session_id},
            )
            if existing.first():
                raise HTTPException(status_code=404, detail="Session not found")

    try:
        return await chat_with_agent(
            message=body.message,
            session_id=session_id,
            tenant_id=str(user.tenant_id),
            user_id=str(user.id),
            db=db,
            tenant_config=tenant_config,
            context_type=body.context_type,
            context_id=body.context_id,
        )
    except HTTPException:
        # Already carries a deliberate status code; don't mask it as a 500.
        raise
    except ValueError as exc:
        # AI not configured
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(exc)
        ) from exc
    except Exception as exc:
        error_msg, error_code = _extract_ai_error(exc)
        # logger.exception records the traceback alongside the provider message.
        logger.exception("Chat error for user %s: %s", user.id, error_msg)

        # Content filter violation → return 422 with user-friendly message
        if error_code == "content_filter":
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="Deine Nachricht wurde vom Azure Content Filter blockiert. Bitte formuliere sie um.",
            ) from exc

        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"AI error: {error_msg[:500]}",
        ) from exc


def _extract_ai_error(exc: Exception) -> tuple[str, str | None]:
    """Return ``(message, code)`` describing an exception from the LLM client.

    OpenAI-style exceptions can expose the provider error in two ways.
    ``exc.body`` may be the full response body,
    ``{"error": {"code": ..., "message": ...}}``.
    It may also be the already-unwrapped error object,
    ``{"code": ..., "message": ...}``, which is what openai>=1.0 stores.
    Such exceptions may additionally carry ``exc.message`` / ``exc.code``.
    All of these are accepted. Anything else falls back to ``str(exc)`` and
    a ``None`` code.
    """
    message = getattr(exc, "message", None) or str(exc)
    code = None

    body = getattr(exc, "body", None)
    if isinstance(body, dict):
        err = body.get("error")
        if not isinstance(err, dict):
            # Body is already the unwrapped error object.
            err = body
        code = err.get("code")
        message = err.get("message") or message

    if code is None:
        code = getattr(exc, "code", None)

    # Only string codes are meaningful to callers (e.g. "content_filter").
    return str(message), (code if isinstance(code, str) else None)
|
|
|
|
|
|
@router.get("/sessions", response_model=list[ChatSessionSummary])
async def list_sessions(
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return up to 50 of the caller's chat sessions, most recent activity first.

    Only messages matching both the current user and tenant are grouped.
    For each session the summary contains:
    - the newest message, clipped to 200 characters;
    - the message count;
    - the earliest message's timestamp as an ``isoformat()`` string, or
      "" when it is missing.
    """
    query = text("""
        SELECT
            session_id::text AS session_id,
            (SELECT content FROM chat_messages cm2
             WHERE cm2.session_id = cm.session_id
             ORDER BY cm2.created_at DESC LIMIT 1) AS last_message,
            COUNT(*) AS message_count,
            MIN(cm.created_at) AS created_at
        FROM chat_messages cm
        WHERE cm.user_id = :uid AND cm.tenant_id = :tid
        GROUP BY cm.session_id
        ORDER BY MAX(cm.created_at) DESC
        LIMIT 50
    """)
    params = {"uid": str(user.id), "tid": str(user.tenant_id)}
    rows = (await db.execute(query, params)).mappings().all()

    summaries = []
    for row in rows:
        started = row["created_at"]
        summaries.append(
            {
                "session_id": row["session_id"],
                "last_message": (row["last_message"] or "")[:200],
                "message_count": row["message_count"],
                "created_at": started.isoformat() if started else "",
            }
        )
    return summaries
|
|
|
|
|
|
@router.get("/sessions/{session_id}/messages", response_model=list[ChatMessageOut])
async def get_session_messages(
    session_id: str,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the messages of one session in chronological order.

    Only rows matching the caller's user and tenant are read. A session with
    no such rows, including another user's session, yields 404.
    """
    lookup = await db.execute(
        text("""
            SELECT id::text, role, content, context_type,
                   context_id::text, token_count, created_at
            FROM chat_messages
            WHERE session_id = :sid AND user_id = :uid AND tenant_id = :tid
            ORDER BY created_at ASC
        """),
        {"sid": session_id, "uid": str(user.id), "tid": str(user.tenant_id)},
    )
    records = lookup.mappings().all()
    if not records:
        raise HTTPException(status_code=404, detail="Session not found")

    # Columns copied verbatim; created_at is rendered separately below.
    passthrough = ("id", "role", "content", "context_type", "context_id", "token_count")
    messages = []
    for rec in records:
        entry = {key: rec[key] for key in passthrough}
        stamp = rec["created_at"]
        entry["created_at"] = stamp.isoformat() if stamp else ""
        messages.append(entry)
    return messages
|
|
|
|
|
|
@router.delete("/sessions/{session_id}")
async def delete_session(
    session_id: str,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Remove every message of one of the caller's sessions.

    Only rows matching the session, the current user and the tenant are
    deleted. The transaction is committed before the affected-row count is
    inspected, and a count of zero yields 404.

    Returns:
        ``{"deleted": <row count>, "session_id": <session_id>}``.
    """
    params = {"sid": session_id, "uid": str(user.id), "tid": str(user.tenant_id)}
    outcome = await db.execute(
        text("""
            DELETE FROM chat_messages
            WHERE session_id = :sid AND user_id = :uid AND tenant_id = :tid
        """),
        params,
    )
    await db.commit()

    removed = outcome.rowcount
    if removed != 0:
        return {"deleted": removed, "session_id": session_id}
    raise HTTPException(status_code=404, detail="Session not found")
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
async def _get_tenant_config(db: AsyncSession, user: User) -> dict | None:
    """Fetch the ``tenant_config`` JSONB column of the user's tenant.

    Returns ``None`` when the user has no tenant id or no matching tenant row
    exists.
    """
    tenant_id = user.tenant_id
    if tenant_id:
        row = (
            await db.execute(
                text("SELECT tenant_config FROM tenants WHERE id = :tid"),
                {"tid": str(tenant_id)},
            )
        ).mappings().first()
        if row:
            return row["tenant_config"]
    return None
|