7a1329958d
- fix(render): ffmpeg overlay=0:0 -> overlay=0:0:shortest=1 to prevent hang on finite PNG sequences - feat(ws): add core/websocket.py ConnectionManager + Redis Pub/Sub subscriber loop - feat(ws): add /api/ws WebSocket endpoint with JWT query-param auth in main.py - feat(ws): emit render_complete/failed + cad_processing_complete events from step_tasks.py - feat(ws): emit order_status_change events from orders router - feat(ws): add beat_tasks.py broadcast_queue_status task (every 10s via Redis __broadcast__) - feat(frontend): add useWebSocket hook with auto-reconnect (exponential backoff, 25s ping) - feat(frontend): add WebSocketContext + WebSocketProvider wrapping App - refactor(frontend): remove polling from WorkerActivity (was 5s/3s) + OrderDetail (was 5s) - refactor(frontend): reduce polling in Layout (8s->60s) + NotificationCenter (15s->60s) - docs: add ffmpeg shortest=1 + WebSocket JWT auth learnings to LEARNINGS.md Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
154 lines
6.6 KiB
Python
154 lines
6.6 KiB
Python
"""WebSocket connection manager with Redis Pub/Sub broadcasting.
|
|
|
|
Architecture:
|
|
- ConnectionManager holds in-memory mapping: tenant_id -> set[WebSocket]
|
|
- A background asyncio task subscribes to Redis Pub/Sub channels
|
|
- Backend tasks/routers call publish_event_sync() (sync, Celery-safe)
|
|
which does redis.publish(f"tenant:{tenant_id}", json.dumps(event))
|
|
- The subscriber loop receives messages and forwards to all WS for that tenant
|
|
|
|
Special channel "__broadcast__" is forwarded to ALL connected clients.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from collections import defaultdict
|
|
from typing import Any
|
|
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
from starlette.websockets import WebSocketState
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ConnectionManager:
    """Track WebSocket connections per tenant and fan out Redis Pub/Sub events.

    Connections are kept in memory, keyed by tenant id.  A single background
    asyncio task (see ``start_redis_subscriber``) listens on the Redis
    channels ``tenant:*`` and ``__broadcast__`` and forwards every JSON
    payload to the matching sockets.
    """

    def __init__(self) -> None:
        # tenant_id (str) -> set of active WebSocket connections
        self._connections: dict[str, set[WebSocket]] = defaultdict(set)
        # Guards mutations of ``_connections``.
        self._lock = asyncio.Lock()
        # Handle of the background Redis subscriber, if one was started.
        self._subscriber_task: asyncio.Task | None = None

    # ── Connection lifecycle ──────────────────────────────────────────────────

    async def connect(self, ws: WebSocket, tenant_id: str) -> None:
        """Accept the WebSocket handshake and register *ws* under *tenant_id*."""
        await ws.accept()
        async with self._lock:
            self._connections[tenant_id].add(ws)
        logger.debug("WS connected tenant=%s total=%d", tenant_id, self._total())

    async def disconnect(self, ws: WebSocket, tenant_id: str) -> None:
        """Unregister *ws*; a no-op if it was never (or is no longer) registered."""
        async with self._lock:
            # ``.get`` avoids creating a defaultdict entry for unknown tenants.
            sockets = self._connections.get(tenant_id)
            if sockets is not None:
                sockets.discard(ws)
                if not sockets:
                    # Drop empty buckets so the mapping does not grow unbounded.
                    del self._connections[tenant_id]
        logger.debug("WS disconnected tenant=%s total=%d", tenant_id, self._total())

    def _total(self) -> int:
        """Return the number of registered sockets across all tenants."""
        return sum(len(sockets) for sockets in self._connections.values())

    # ── Broadcast ─────────────────────────────────────────────────────────────

    async def _send_to_sockets(
        self, message: str, targets: list[tuple[str, WebSocket]]
    ) -> None:
        """Send *message* to each ``(tenant_id, ws)`` target and prune dead ones.

        A socket is considered dead when sending raises, or when its client
        side is no longer in the CONNECTED state.  Dead sockets are removed
        after the send pass so the registry is not mutated mid-iteration.
        """
        dead: list[tuple[str, WebSocket]] = []
        for tenant_id, ws in targets:
            try:
                if ws.client_state != WebSocketState.CONNECTED:
                    # connect() only registers accepted sockets, so a socket
                    # that is not CONNECTED any more will never be usable again.
                    dead.append((tenant_id, ws))
                    continue
                await ws.send_text(message)
            except Exception:
                dead.append((tenant_id, ws))
        for tenant_id, ws in dead:
            await self.disconnect(ws, tenant_id)

    async def broadcast_to_tenant(self, tenant_id: str, event: dict[str, Any]) -> None:
        """Send event JSON to all WebSockets for a tenant."""
        message = json.dumps(event)
        # Snapshot the set: sending awaits, and the registry may change meanwhile.
        targets = [(tenant_id, ws) for ws in self._connections.get(tenant_id, ())]
        await self._send_to_sockets(message, targets)

    async def broadcast_all(self, event: dict[str, Any]) -> None:
        """Send event to ALL connected WebSockets regardless of tenant."""
        message = json.dumps(event)
        async with self._lock:
            targets = [
                (tenant_id, ws)
                for tenant_id, sockets in self._connections.items()
                for ws in sockets
            ]
        await self._send_to_sockets(message, targets)

    # ── Redis Pub/Sub subscriber ──────────────────────────────────────────────

    async def start_redis_subscriber(self) -> None:
        """Start background task that listens for Redis Pub/Sub messages.

        Does nothing while a subscriber task is still running.  A task that
        has already finished (e.g. it crashed) is replaced by a fresh one.
        """
        task = self._subscriber_task
        if task is not None and not task.done():
            return
        self._subscriber_task = asyncio.create_task(self._subscribe_loop())
        logger.info("WebSocket Redis subscriber started")

    @staticmethod
    async def _close_quietly(resource: Any) -> None:
        """Close an async Redis client/pubsub, logging instead of raising.

        Prefers ``aclose()`` when the object has it and falls back to
        ``close()`` otherwise, so both naming conventions are handled.
        """
        closer = getattr(resource, "aclose", None) or resource.close
        try:
            await closer()
        except Exception as exc:
            # Cleanup must not mask the error that triggered it.
            logger.debug("Error while closing Redis resource: %s", exc)

    async def _dispatch_message(self, message: dict[str, Any]) -> None:
        """Decode one raw Pub/Sub message and forward it to the right sockets."""
        if message["type"] not in ("message", "pmessage"):
            # Subscribe/unsubscribe confirmations carry no payload.
            return
        channel: str = message.get("channel") or message.get("pattern") or ""
        try:
            event = json.loads(message.get("data", ""))
        except (json.JSONDecodeError, TypeError):
            # Ignore payloads that are not valid JSON.
            return

        if channel == "__broadcast__":
            await self.broadcast_all(event)
        elif channel.startswith("tenant:"):
            await self.broadcast_to_tenant(channel[len("tenant:"):], event)

    async def _subscribe_loop(self) -> None:
        """Listen on Redis forever, reconnecting 3 s after any error.

        Each iteration opens its own client and pubsub and always closes
        them again, so reconnects and cancellation do not leak connections.
        The loop ends quietly when the task is cancelled.
        """
        from app.config import settings
        import redis.asyncio as aioredis

        while True:
            try:
                client = aioredis.from_url(settings.redis_url, decode_responses=True)
                pubsub = client.pubsub()
                try:
                    await pubsub.psubscribe("tenant:*", "__broadcast__")
                    logger.info("Subscribed to Redis channels tenant:* and __broadcast__")
                    async for message in pubsub.listen():
                        await self._dispatch_message(message)
                finally:
                    # Release the connection before sleeping/reconnecting/exiting.
                    await self._close_quietly(pubsub)
                    await self._close_quietly(client)
            except asyncio.CancelledError:
                break
            except Exception as exc:
                logger.warning("Redis subscriber error, reconnecting in 3s: %s", exc)
                await asyncio.sleep(3)

    async def stop(self) -> None:
        """Cancel the subscriber task (if any) and wait for it to finish."""
        task = self._subscriber_task
        if not task:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        finally:
            # Always forget the handle so the subscriber can be started again.
            self._subscriber_task = None
|
|
|
|
|
|
# ── Singleton instance ────────────────────────────────────────────────────────
|
|
|
|
# Instantiated once at import time; presumably the instance the WebSocket
# endpoint and the app startup/shutdown hooks share — confirm against callers.
manager = ConnectionManager()
|
|
|
|
|
|
# ── Sync helper for Celery tasks ──────────────────────────────────────────────
|
|
|
|
def publish_event_sync(tenant_id: str, event: dict[str, Any]) -> None:
    """Publish a WebSocket event from a synchronous Celery task.

    Uses a plain (sync) Redis client to publish to the Pub/Sub channel.
    The async subscriber loop in the FastAPI process will forward it to WS clients.

    Args:
        tenant_id: Tenant whose channel ``tenant:{tenant_id}`` receives the
            event.  The special value ``"__broadcast__"`` publishes to the
            ``__broadcast__`` channel instead.
        event: JSON-serialisable payload.

    Best effort: any failure (import, serialisation, connection, publish) is
    logged as a warning and never raised to the caller.
    """
    try:
        import redis as sync_redis
        from app.config import settings

        channel = "__broadcast__" if tenant_id == "__broadcast__" else f"tenant:{tenant_id}"
        # Serialise first so a bad payload fails before a connection is opened.
        payload = json.dumps(event)

        client = sync_redis.from_url(settings.redis_url, decode_responses=True)
        try:
            client.publish(channel, payload)
        finally:
            # Close even when publish() raises, so the connection is not leaked.
            client.close()
    except Exception as exc:
        logger.warning("publish_event_sync failed: %s", exc)
|