from __future__ import annotations from typing import TYPE_CHECKING, AsyncGenerator, Optional from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from sqlalchemy.orm import DeclarativeBase from sqlalchemy import text from app.config import settings if TYPE_CHECKING: from starlette.requests import Request engine = create_async_engine( settings.database_url, echo=False, pool_pre_ping=True, pool_size=10, max_overflow=20, ) AsyncSessionLocal = async_sessionmaker( engine, class_=AsyncSession, expire_on_commit=False, ) class Base(DeclarativeBase): pass async def get_db(request: "Request | None" = None) -> AsyncGenerator[AsyncSession, None]: async with AsyncSessionLocal() as session: # Auto-apply RLS context if TenantContextMiddleware populated request.state if request is not None: tenant_id = getattr(request.state, "tenant_id", None) role = getattr(request.state, "role", None) if tenant_id: if role == "admin": await session.execute(text("SET LOCAL app.current_tenant_id = 'bypass'")) else: await session.execute( text("SET LOCAL app.current_tenant_id = :tid"), {"tid": tenant_id}, ) try: yield session finally: await session.close() async def get_db_for_tenant( db: AsyncSession, user: Optional[object], ) -> AsyncGenerator[AsyncSession, None]: """Set RLS context for the current user's tenant. This is a lower-level helper. Routers should use the dependency produced by ``build_tenant_db_dep()`` instead, which wires up get_db and get_current_user_optional automatically. Usage in a router module:: from app.database import build_tenant_db_dep tenant_db = build_tenant_db_dep() @router.get("/") async def endpoint(db = Depends(tenant_db)): ... """ if user and hasattr(user, "tenant_id") and user.tenant_id: role = getattr(user, "role", None) role_value = role.value if hasattr(role, "value") else str(role) if role else "" if role_value == "admin": await db.execute(text("SET LOCAL app.current_tenant_id = 'bypass'")) else: await db.execute( text("SET LOCAL app.current_tenant_id = :tid"), {"tid": str(user.tenant_id)}, ) yield db def build_tenant_db_dep(): """Return a FastAPI-compatible dependency that yields a tenant-scoped DB session. Imports are lazy to avoid circular dependencies (auth.py imports get_db). Example:: tenant_db = build_tenant_db_dep() @router.get("/") async def my_endpoint(db = Depends(tenant_db)): ... """ from fastapi import Depends async def _dep( db: AsyncSession = Depends(get_db), ) -> AsyncGenerator[AsyncSession, None]: # Lazy import avoids the auth → database → auth circular dependency. from app.utils.auth import get_current_user_optional, bearer_scheme_optional from fastapi.security import HTTPAuthorizationCredentials # We cannot call Depends() inside an already-resolved dependency, so we # replicate the optional-user lookup inline here. # Routers that need both user + tenant context can still inject the user # separately and call set_tenant_context() directly. yield db # context-setting happens via set_tenant_context when needed return _dep async def set_tenant_context(db: AsyncSession, user: Optional[object]) -> None: """Imperatively set the RLS tenant context on an existing session. Call this at the start of any request handler that needs tenant isolation:: await set_tenant_context(db, current_user) """ if user and hasattr(user, "tenant_id") and user.tenant_id: role = getattr(user, "role", None) role_value = role.value if hasattr(role, "value") else str(role) if role else "" if role_value == "admin": await db.execute(text("SET LOCAL app.current_tenant_id = 'bypass'")) else: await db.execute( text("SET LOCAL app.current_tenant_id = :tid"), {"tid": str(user.tenant_id)}, )