Files
HartOMat/backend/app/core/websocket.py
T
Hartmut 7a1329958d feat(J): WebSocket live-events + replace polling + fix ffmpeg turntable timeout
- fix(render): ffmpeg overlay=0:0 -> overlay=0:0:shortest=1 to prevent hang on finite PNG sequences
- feat(ws): add core/websocket.py ConnectionManager + Redis Pub/Sub subscriber loop
- feat(ws): add /api/ws WebSocket endpoint with JWT query-param auth in main.py
- feat(ws): emit render_complete/failed + cad_processing_complete events from step_tasks.py
- feat(ws): emit order_status_change events from orders router
- feat(ws): add beat_tasks.py broadcast_queue_status task (every 10s via Redis __broadcast__)
- feat(frontend): add useWebSocket hook with auto-reconnect (exponential backoff, 25s ping)
- feat(frontend): add WebSocketContext + WebSocketProvider wrapping App
- refactor(frontend): remove polling from WorkerActivity (was 5s/3s) + OrderDetail (was 5s)
- refactor(frontend): reduce polling in Layout (8s->60s) + NotificationCenter (15s->60s)
- docs: add ffmpeg shortest=1 + WebSocket JWT auth learnings to LEARNINGS.md

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-06 20:49:34 +01:00

154 lines
6.6 KiB
Python

"""WebSocket connection manager with Redis Pub/Sub broadcasting.
Architecture:
- ConnectionManager holds in-memory mapping: tenant_id -> set[WebSocket]
- A background asyncio task subscribes to Redis Pub/Sub channels
- Backend tasks/routers call publish_event_sync() (sync, Celery-safe)
which does redis.publish(f"tenant:{tenant_id}", json.dumps(event))
- The subscriber loop receives messages and forwards to all WS for that tenant
Special channel "__broadcast__" is forwarded to ALL connected clients.
"""
from __future__ import annotations
import asyncio
import json
import logging
from collections import defaultdict
from typing import Any
from fastapi import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState
logger = logging.getLogger(__name__)
class ConnectionManager:
    """Tracks live WebSocket connections per tenant and fans Redis Pub/Sub
    events out to them.

    - In-memory registry: tenant_id -> set[WebSocket] (this process only).
    - A single background task subscribes to ``tenant:*`` and ``__broadcast__``
      Redis channels and forwards decoded JSON events to the matching sockets.
    - ``__broadcast__`` messages go to every connected client of every tenant.
    """

    def __init__(self) -> None:
        # tenant_id (str) -> set of active WebSocket connections
        self._connections: dict[str, set[WebSocket]] = defaultdict(set)
        # Guards mutation of self._connections (connect/disconnect/snapshot).
        self._lock = asyncio.Lock()
        # Background Redis subscriber task; None until start_redis_subscriber().
        self._subscriber_task: asyncio.Task | None = None

    # ── Connection lifecycle ──────────────────────────────────────────────────

    async def connect(self, ws: WebSocket, tenant_id: str) -> None:
        """Accept the WebSocket handshake and register it under *tenant_id*."""
        await ws.accept()
        async with self._lock:
            self._connections[tenant_id].add(ws)
        logger.debug("WS connected tenant=%s total=%d", tenant_id, self._total())

    async def disconnect(self, ws: WebSocket, tenant_id: str) -> None:
        """Unregister *ws*; drop the tenant entry entirely once it is empty."""
        async with self._lock:
            self._connections[tenant_id].discard(ws)
            if not self._connections[tenant_id]:
                # defaultdict would otherwise keep empty sets around forever.
                del self._connections[tenant_id]
        logger.debug("WS disconnected tenant=%s total=%d", tenant_id, self._total())

    def _total(self) -> int:
        """Total live connections across all tenants (used for log lines)."""
        return sum(len(s) for s in self._connections.values())

    # ── Broadcast ─────────────────────────────────────────────────────────────

    async def broadcast_to_tenant(self, tenant_id: str, event: dict[str, Any]) -> None:
        """Send *event* (JSON-encoded) to every WebSocket of one tenant.

        Sockets that raise on send are treated as dead and unregistered.
        """
        message = json.dumps(event)
        # FIX: snapshot under the lock, consistent with broadcast_all(), so a
        # concurrent connect/disconnect cannot mutate the set mid-iteration.
        async with self._lock:
            targets = list(self._connections.get(tenant_id, ()))
        dead: list[WebSocket] = []
        for ws in targets:
            try:
                if ws.client_state == WebSocketState.CONNECTED:
                    await ws.send_text(message)
            except Exception:
                dead.append(ws)
        for ws in dead:
            await self.disconnect(ws, tenant_id)

    async def broadcast_all(self, event: dict[str, Any]) -> None:
        """Send *event* to ALL connected WebSockets regardless of tenant."""
        message = json.dumps(event)
        dead: list[tuple[WebSocket, str]] = []
        # Snapshot under the lock; sends happen outside it so a slow socket
        # cannot block connect()/disconnect().
        async with self._lock:
            snapshot = {tid: set(sockets) for tid, sockets in self._connections.items()}
        for tenant_id, sockets in snapshot.items():
            for ws in sockets:
                try:
                    if ws.client_state == WebSocketState.CONNECTED:
                        await ws.send_text(message)
                except Exception:
                    dead.append((ws, tenant_id))
        for ws, tid in dead:
            await self.disconnect(ws, tid)

    # ── Redis Pub/Sub subscriber ──────────────────────────────────────────────

    async def start_redis_subscriber(self) -> None:
        """Start the background task that listens for Redis Pub/Sub messages.

        Idempotent: a second call while the task exists is a no-op.
        """
        if self._subscriber_task is not None:
            return
        self._subscriber_task = asyncio.create_task(self._subscribe_loop())
        logger.info("WebSocket Redis subscriber started")

    async def _subscribe_loop(self) -> None:
        """Reconnect-forever loop: subscribe, forward messages, retry on error."""
        # Imported lazily so this module can be imported without app config
        # (e.g. from Celery workers that only use publish_event_sync()).
        from app.config import settings
        import redis.asyncio as aioredis

        while True:
            client = None
            pubsub = None
            try:
                client = aioredis.from_url(settings.redis_url, decode_responses=True)
                pubsub = client.pubsub()
                # psubscribe: "tenant:*" is a glob pattern; "__broadcast__"
                # contains no wildcards so it matches only itself.
                await pubsub.psubscribe("tenant:*", "__broadcast__")
                logger.info("Subscribed to Redis channels tenant:* and __broadcast__")
                async for message in pubsub.listen():
                    if message["type"] not in ("message", "pmessage"):
                        continue  # skip subscribe confirmations etc.
                    channel: str = message.get("channel") or message.get("pattern") or ""
                    data_str: str = message.get("data", "")
                    try:
                        event = json.loads(data_str)
                    except (json.JSONDecodeError, TypeError):
                        continue  # ignore non-JSON payloads
                    if channel == "__broadcast__":
                        await self.broadcast_all(event)
                    elif channel.startswith("tenant:"):
                        tenant_id = channel[len("tenant:"):]
                        await self.broadcast_to_tenant(tenant_id, event)
            except asyncio.CancelledError:
                break
            except Exception as exc:
                logger.warning("Redis subscriber error, reconnecting in 3s: %s", exc)
                await asyncio.sleep(3)
            finally:
                # FIX: close the pubsub and client before looping. Previously
                # each reconnect created a fresh client without closing the old
                # one, leaking a Redis connection per failure.
                for closable in (pubsub, client):
                    if closable is None:
                        continue
                    try:
                        # redis-py >= 5 prefers aclose(); close() on older.
                        closer = getattr(closable, "aclose", None) or closable.close
                        await closer()
                    except Exception:
                        pass

    async def stop(self) -> None:
        """Cancel the subscriber task and wait for it to finish."""
        if self._subscriber_task:
            self._subscriber_task.cancel()
            try:
                await self._subscriber_task
            except asyncio.CancelledError:
                pass
            self._subscriber_task = None
# ── Singleton instance ────────────────────────────────────────────────────────
# Process-wide singleton: the FastAPI app and its WS endpoint import this one
# instance so all routes share the same connection registry.
manager = ConnectionManager()
# ── Sync helper for Celery tasks ──────────────────────────────────────────────
def publish_event_sync(tenant_id: str, event: dict[str, Any]) -> None:
    """Publish a WebSocket event from a synchronous Celery task.

    Uses a plain (sync) Redis client to publish to the Pub/Sub channel.
    The async subscriber loop in the FastAPI process forwards it to WS clients.

    Best-effort by design: any failure (missing redis, bad config, network)
    is logged as a warning and never raised to the caller.

    Args:
        tenant_id: Target tenant, or the literal "__broadcast__" to reach all
            connected clients of every tenant.
        event: JSON-serializable event payload.
    """
    try:
        # Imported lazily so importing this module never requires redis/config.
        import redis as sync_redis
        from app.config import settings

        r = sync_redis.from_url(settings.redis_url, decode_responses=True)
        try:
            channel = f"tenant:{tenant_id}" if tenant_id != "__broadcast__" else "__broadcast__"
            r.publish(channel, json.dumps(event))
        finally:
            # FIX: close in finally — previously a publish error skipped
            # r.close() and leaked the connection.
            r.close()
    except Exception as exc:
        logger.warning("publish_event_sync failed: %s", exc)