"""WebSocket connection manager with Redis Pub/Sub broadcasting. Architecture: - ConnectionManager holds in-memory mapping: tenant_id -> set[WebSocket] - A background asyncio task subscribes to Redis Pub/Sub channels - Backend tasks/routers call publish_event_sync() (sync, Celery-safe) which does redis.publish(f"tenant:{tenant_id}", json.dumps(event)) - The subscriber loop receives messages and forwards to all WS for that tenant Special channel "__broadcast__" is forwarded to ALL connected clients. """ from __future__ import annotations import asyncio import json import logging from collections import defaultdict from typing import Any from fastapi import WebSocket, WebSocketDisconnect from starlette.websockets import WebSocketState logger = logging.getLogger(__name__) class ConnectionManager: def __init__(self) -> None: # tenant_id (str) -> set of active WebSocket connections self._connections: dict[str, set[WebSocket]] = defaultdict(set) self._lock = asyncio.Lock() self._subscriber_task: asyncio.Task | None = None # ── Connection lifecycle ────────────────────────────────────────────────── async def connect(self, ws: WebSocket, tenant_id: str) -> None: await ws.accept() async with self._lock: self._connections[tenant_id].add(ws) logger.debug("WS connected tenant=%s total=%d", tenant_id, self._total()) async def disconnect(self, ws: WebSocket, tenant_id: str) -> None: async with self._lock: self._connections[tenant_id].discard(ws) if not self._connections[tenant_id]: del self._connections[tenant_id] logger.debug("WS disconnected tenant=%s total=%d", tenant_id, self._total()) def _total(self) -> int: return sum(len(s) for s in self._connections.values()) # ── Broadcast ───────────────────────────────────────────────────────────── async def broadcast_to_tenant(self, tenant_id: str, event: dict[str, Any]) -> None: """Send event JSON to all WebSockets for a tenant.""" message = json.dumps(event) dead: list[WebSocket] = [] for ws in list(self._connections.get(tenant_id, set())): try: if ws.client_state == WebSocketState.CONNECTED: await ws.send_text(message) except Exception: dead.append(ws) for ws in dead: await self.disconnect(ws, tenant_id) async def broadcast_all(self, event: dict[str, Any]) -> None: """Send event to ALL connected WebSockets regardless of tenant.""" message = json.dumps(event) dead: list[tuple[WebSocket, str]] = [] async with self._lock: snapshot = {tid: set(sockets) for tid, sockets in self._connections.items()} for tenant_id, sockets in snapshot.items(): for ws in sockets: try: if ws.client_state == WebSocketState.CONNECTED: await ws.send_text(message) except Exception: dead.append((ws, tenant_id)) for ws, tid in dead: await self.disconnect(ws, tid) # ── Redis Pub/Sub subscriber ────────────────────────────────────────────── async def start_redis_subscriber(self) -> None: """Start background task that listens for Redis Pub/Sub messages.""" if self._subscriber_task is not None: return self._subscriber_task = asyncio.create_task(self._subscribe_loop()) logger.info("WebSocket Redis subscriber started") async def _subscribe_loop(self) -> None: from app.config import settings import redis.asyncio as aioredis while True: try: client = aioredis.from_url(settings.redis_url, decode_responses=True) pubsub = client.pubsub() await pubsub.psubscribe("tenant:*", "__broadcast__") logger.info("Subscribed to Redis channels tenant:* and __broadcast__") async for message in pubsub.listen(): if message["type"] not in ("message", "pmessage"): continue channel: str = message.get("channel") or message.get("pattern") or "" data_str: str = message.get("data", "") try: event = json.loads(data_str) except (json.JSONDecodeError, TypeError): continue if channel == "__broadcast__": await self.broadcast_all(event) elif channel.startswith("tenant:"): tenant_id = channel[len("tenant:"):] await self.broadcast_to_tenant(tenant_id, event) except asyncio.CancelledError: break except Exception as exc: logger.warning("Redis subscriber error, reconnecting in 3s: %s", exc) await asyncio.sleep(3) async def stop(self) -> None: if self._subscriber_task: self._subscriber_task.cancel() try: await self._subscriber_task except asyncio.CancelledError: pass self._subscriber_task = None # ── Singleton instance ──────────────────────────────────────────────────────── manager = ConnectionManager() # ── Sync helper for Celery tasks ────────────────────────────────────────────── def publish_event_sync(tenant_id: str, event: dict[str, Any]) -> None: """Publish a WebSocket event from a synchronous Celery task. Uses a plain (sync) Redis client to publish to the Pub/Sub channel. The async subscriber loop in the FastAPI process will forward it to WS clients. """ try: import redis as sync_redis from app.config import settings r = sync_redis.from_url(settings.redis_url, decode_responses=True) channel = f"tenant:{tenant_id}" if tenant_id != "__broadcast__" else "__broadcast__" r.publish(channel, json.dumps(event)) r.close() except Exception as exc: logger.warning("publish_event_sync failed: %s", exc)