"""Database utilities for use inside Celery tasks (sync context).

Celery tasks bypass FastAPI middleware, so the tenant RLS context must be set
manually. The simplest way is to pass the tenant to ``get_sync_session``, which
(re)applies it at the start of every transaction the session opens.

Usage::

    from app.core.db_utils import get_sync_session, set_tenant_sync

    with get_sync_session(tenant_id, role) as db:
        # queries here will be RLS-filtered, including after an intermediate
        # db.commit() / db.rollback() starts a new transaction
        ...

    # Alternatively, set the context by hand right after creating the session.
    # This only lasts until the current transaction ends (commit/rollback).
    with get_sync_session() as db:
        set_tenant_sync(db, tenant_id, role)
        ...

Note: the tenant statements use PostgreSQL's ``set_config``; they will fail on
other backends (e.g. SQLite) whenever a tenant_id is supplied.
"""

import contextlib
from typing import Generator

from sqlalchemy import create_engine, event, text
from sqlalchemy.engine import Connection
from sqlalchemy.orm import Session, sessionmaker

from app.config import settings

# Transaction-scoped equivalent of ``SET LOCAL app.current_tenant_id = ...``
# (third argument ``true`` == is_local). ``SET`` is a utility statement and
# cannot take bind parameters on drivers that bind server-side (e.g.
# psycopg 3); ``set_config`` is an ordinary function call and can.
_SET_TENANT_SQL = text("SELECT set_config('app.current_tenant_id', :tid, true)")

# Value written for admins instead of a tenant id. The RLS policies are
# presumably written to treat it as "no tenant filter" — confirm against the
# policy definitions.
_ADMIN_BYPASS = "bypass"


def _apply_tenant(executor: Session | Connection, tenant_id: str, role: str | None) -> None:
    """Write the tenant setting for the current transaction of *executor*.

    *executor* may be a Session or a Connection; both expose ``execute``.
    Callers must have already checked that *tenant_id* is truthy.
    """
    value = _ADMIN_BYPASS if role == "admin" else str(tenant_id)
    executor.execute(_SET_TENANT_SQL, {"tid": value})


def set_tenant_sync(session: Session, tenant_id: str | None, role: str | None = None) -> None:
    """Set the RLS tenant context on a *synchronous* SQLAlchemy session.

    Call this at the start of any sync DB block inside a Celery task when you
    need tenant isolation. Role ``"admin"`` stores the ``'bypass'`` sentinel;
    every other role stores the tenant id.

    The setting is transaction-scoped: it is discarded at the session's next
    ``commit()`` / ``rollback()``. Call it again after that, or pass
    ``tenant_id``/``role`` to ``get_sync_session``, which re-applies it
    automatically.

    If ``tenant_id`` is falsy (e.g. None) the call is a no-op and nothing is
    written. Whether queries are then unrestricted depends on how the RLS
    policies treat an unset ``app.current_tenant_id``.
    """
    if not tenant_id:
        return
    _apply_tenant(session, tenant_id, role)


# Lazily created sync engine (reused across tasks in the same worker process)
_sync_engine = None


def _get_sync_engine():
    """Return the process-wide sync engine, creating it on first use.

    The async driver suffix (``+asyncpg`` / ``+aiosqlite``) is stripped from
    ``settings.database_url`` so SQLAlchemy falls back to the dialect's
    default sync driver.
    """
    global _sync_engine
    if _sync_engine is None:
        sync_url = settings.database_url.replace("+asyncpg", "").replace("+aiosqlite", "")
        _sync_engine = create_engine(sync_url, pool_pre_ping=True, pool_size=5, max_overflow=10)
    return _sync_engine


@contextlib.contextmanager
def get_sync_session(
    tenant_id: str | None = None,
    role: str | None = None,
) -> Generator[Session, None, None]:
    """Context manager that yields a synchronous DB session with optional RLS.

    When ``tenant_id`` is truthy, the tenant context is written at the start
    of every transaction the session begins. That includes the first
    (autobegun) one and any started after an intermediate ``commit()`` /
    ``rollback()`` inside the ``with`` block, so tenant filtering cannot
    silently lapse mid-task.

    On normal exit the session is committed. If the body raises, the session
    is rolled back and the exception re-raised. The session is closed in
    both cases.

    Prefer using the existing async session patterns in FastAPI routes.
    This helper is intended for Celery tasks only.
    """
    factory = sessionmaker(bind=_get_sync_engine(), expire_on_commit=False)
    with factory() as session:
        if tenant_id:
            # after_begin fires whenever the session acquires a connection for a
            # new transaction, before the triggering statement runs, so every
            # transaction gets the (transaction-scoped) tenant setting.
            def _reapply_tenant(_session, _transaction, connection):
                _apply_tenant(connection, tenant_id, role)

            event.listen(session, "after_begin", _reapply_tenant)
        try:
            yield session
        except Exception:
            session.rollback()
            raise
        else:
            session.commit()