"""Sync tenant context helpers for Celery tasks. Celery tasks run in a sync context (no async event loop), so they cannot use the async ``set_tenant_context`` from ``app.database``. This module provides ``set_tenant_context_sync`` which accepts a SQLAlchemy sync ``Session`` and a raw ``tenant_id`` UUID string (or None for global-admin bypass), as well as ``resolve_tenant_id_for_cad`` / ``resolve_tenant_id_for_order_line`` helpers that look up the tenant_id from the database given only an entity ID. Typical usage at the start of a Celery task:: from app.core.tenant_context import resolve_tenant_id_for_cad, set_tenant_context_sync tenant_id = resolve_tenant_id_for_cad(cad_file_id) # tenant_id is already logged by resolve_tenant_id_for_cad # Then in every Session block that does RLS-protected queries: with Session(engine) as session: set_tenant_context_sync(session, tenant_id) # ... queries here respect RLS ... """ import logging from typing import Optional from sqlalchemy import create_engine, text from sqlalchemy.orm import Session logger = logging.getLogger(__name__) def set_tenant_context_sync(db: Session, tenant_id: Optional[str]) -> None: """Set the PostgreSQL RLS context variable for a sync SQLAlchemy session. Executes ``SET LOCAL app.current_tenant_id = :tid`` so that all subsequent queries within the same transaction respect row-level security policies. Args: db: An open sync SQLAlchemy ``Session``. tenant_id: UUID string of the tenant, or ``None`` / empty string to use the bypass sentinel (global-admin context — sees all rows). """ if tenant_id: db.execute( text("SET LOCAL app.current_tenant_id = :tid"), {"tid": str(tenant_id)}, ) else: # None means no tenant context is known (e.g. system tasks). # Use empty string — RLS policies treat '' as no-tenant, which allows # global admin queries to proceed without filtering. db.execute(text("SET LOCAL app.current_tenant_id = ''")) def resolve_tenant_id_for_cad(cad_file_id: str) -> Optional[str]: """Look up the tenant_id for a CadFile by its primary key. Opens a short-lived sync session, reads CadFile.tenant_id, and returns it as a string UUID or None. Also emits the ``[TENANT]`` log line. Args: cad_file_id: The UUID string (or UUID) of the CadFile record. Returns: tenant_id as ``str`` if the CadFile has one, ``None`` otherwise. """ try: from app.config import settings as _cfg from app.models.cad_file import CadFile # compat shim → domains.products.models _sync_url = _cfg.database_url.replace("+asyncpg", "") _eng = create_engine(_sync_url) try: with Session(_eng) as _sess: _cad = _sess.get(CadFile, cad_file_id) tenant_id = str(_cad.tenant_id) if (_cad and _cad.tenant_id) else None finally: _eng.dispose() except Exception as exc: logger.warning("[TENANT] resolve_tenant_id_for_cad(%s) failed: %s", cad_file_id, exc) tenant_id = None logger.info("[TENANT] context set: tenant_id=%s", tenant_id) return tenant_id def resolve_tenant_id_for_order_line(order_line_id: str) -> Optional[str]: """Look up the tenant_id for an OrderLine by its primary key. Opens a short-lived sync session, reads OrderLine.tenant_id, and returns it as a string UUID or None. Also emits the ``[TENANT]`` log line. Args: order_line_id: The UUID string (or UUID) of the OrderLine record. Returns: tenant_id as ``str`` if the OrderLine has one, ``None`` otherwise. """ try: from app.config import settings as _cfg from app.models.order_line import OrderLine # compat shim _sync_url = _cfg.database_url.replace("+asyncpg", "") _eng = create_engine(_sync_url) try: with Session(_eng) as _sess: _line = _sess.get(OrderLine, order_line_id) tenant_id = str(_line.tenant_id) if (_line and _line.tenant_id) else None finally: _eng.dispose() except Exception as exc: logger.warning("[TENANT] resolve_tenant_id_for_order_line(%s) failed: %s", order_line_id, exc) tenant_id = None logger.info("[TENANT] context set: tenant_id=%s", tenant_id) return tenant_id