"""Application middleware. TenantContextMiddleware Decodes the JWT Bearer token (if present) from every incoming request and stores tenant_id + role in request.state. The get_db dependency reads request.state to automatically set the RLS context before yielding the session — no endpoint code change required. """ import logging from jose import JWTError, jwt from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response from app.config import settings _log = logging.getLogger(__name__) class TenantContextMiddleware(BaseHTTPMiddleware): """Extract JWT → inject tenant_id + role into request.state. Does NOT reject unauthenticated requests — that is still handled by the route-level dependencies (require_admin, get_current_user, etc.). Missing / invalid tokens result in request.state.tenant_id = None. """ async def dispatch(self, request: Request, call_next) -> Response: tenant_id: str | None = None role: str | None = None auth_header = request.headers.get("Authorization", "") if auth_header.startswith("Bearer "): token = auth_header[7:] try: payload = jwt.decode( token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm], ) tenant_id = payload.get("tenant_id") role = payload.get("role") except JWTError: pass # invalid/expired tokens are handled per-endpoint request.state.tenant_id = tenant_id request.state.role = role return await call_next(request)