feat(J): WebSocket live-events + replace polling + fix ffmpeg turntable timeout
- fix(render): ffmpeg overlay=0:0 -> overlay=0:0:shortest=1 to prevent hang on finite PNG sequences
- feat(ws): add core/websocket.py ConnectionManager + Redis Pub/Sub subscriber loop
- feat(ws): add /api/ws WebSocket endpoint with JWT query-param auth in main.py
- feat(ws): emit render_complete/failed + cad_processing_complete events from step_tasks.py
- feat(ws): emit order_status_change events from orders router
- feat(ws): add beat_tasks.py broadcast_queue_status task (every 10s via Redis __broadcast__)
- feat(frontend): add useWebSocket hook with auto-reconnect (exponential backoff, 25s ping)
- feat(frontend): add WebSocketContext + WebSocketProvider wrapping App
- refactor(frontend): remove polling from WorkerActivity (was 5s/3s) + OrderDetail (was 5s)
- refactor(frontend): reduce polling in Layout (8s->60s) + NotificationCenter (15s->60s)
- docs: add ffmpeg shortest=1 + WebSocket JWT auth learnings to LEARNINGS.md

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -486,6 +486,20 @@ async def submit_order(
|
||||
from app.services.pricing_service import refresh_order_price
|
||||
await refresh_order_price(db, order.id)
|
||||
await db.refresh(order)
|
||||
|
||||
# Broadcast WebSocket event for live UI updates
|
||||
try:
|
||||
from app.core.websocket import manager as _ws_mgr
|
||||
_tid = str(user.tenant_id) if user.tenant_id else None
|
||||
if _tid:
|
||||
await _ws_mgr.broadcast_to_tenant(_tid, {
|
||||
"type": "order_status_change",
|
||||
"order_id": str(order.id),
|
||||
"status": "submitted",
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return order
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
"""WebSocket connection manager with Redis Pub/Sub broadcasting.
|
||||
|
||||
Architecture:
|
||||
- ConnectionManager holds in-memory mapping: tenant_id -> set[WebSocket]
|
||||
- A background asyncio task subscribes to Redis Pub/Sub channels
|
||||
- Backend tasks/routers call publish_event_sync() (sync, Celery-safe)
|
||||
which does redis.publish(f"tenant:{tenant_id}", json.dumps(event))
|
||||
- The subscriber loop receives messages and forwards to all WS for that tenant
|
||||
|
||||
Special channel "__broadcast__" is forwarded to ALL connected clients.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionManager:
    """Track live WebSocket connections per tenant and fan events out to them.

    Connections are kept in an in-memory mapping ``tenant_id -> set[WebSocket]``.
    A single background task subscribes to the Redis channels ``tenant:*`` and
    ``__broadcast__`` and forwards every JSON message it receives to the
    matching sockets.
    """

    def __init__(self) -> None:
        # tenant_id (str) -> set of active WebSocket connections
        self._connections: dict[str, set[WebSocket]] = defaultdict(set)
        # Serialises mutations (and full snapshots) of ``_connections``.
        self._lock = asyncio.Lock()
        # Handle of the Redis subscriber task; None while not running.
        self._subscriber_task: asyncio.Task | None = None

    # ── Connection lifecycle ──────────────────────────────────────────────────

    async def connect(self, ws: "WebSocket", tenant_id: str) -> None:
        """Accept the handshake on ``ws`` and register it under ``tenant_id``."""
        await ws.accept()
        async with self._lock:
            self._connections[tenant_id].add(ws)
        logger.debug("WS connected tenant=%s total=%d", tenant_id, self._total())

    async def disconnect(self, ws: "WebSocket", tenant_id: str) -> None:
        """Forget ``ws``; drop the tenant entry once its last socket is gone.

        Safe to call for a socket or tenant that is not registered.
        """
        async with self._lock:
            sockets = self._connections.get(tenant_id)
            if sockets is not None:
                sockets.discard(ws)
                if not sockets:
                    del self._connections[tenant_id]
        logger.debug("WS disconnected tenant=%s total=%d", tenant_id, self._total())

    def _total(self) -> int:
        """Return the number of registered sockets across all tenants."""
        return sum(len(s) for s in self._connections.values())

    # ── Broadcast ─────────────────────────────────────────────────────────────

    async def _send_to(self, sockets, message: str) -> list:
        """Send ``message`` to every connected socket in ``sockets``.

        Sockets whose client state is not CONNECTED are skipped. Returns the
        sockets for which the send (or the state check) raised, so the caller
        can unregister them.
        """
        failed = []
        for ws in sockets:
            try:
                if ws.client_state == WebSocketState.CONNECTED:
                    await ws.send_text(message)
            except Exception:
                failed.append(ws)
        return failed

    async def broadcast_to_tenant(self, tenant_id: str, event: dict[str, Any]) -> None:
        """Send event JSON to all WebSockets for a tenant."""
        message = json.dumps(event)
        # Copy first: the set may change while we await sends. ``get`` (not
        # indexing) avoids creating an empty defaultdict entry for the tenant.
        targets = list(self._connections.get(tenant_id, ()))
        for ws in await self._send_to(targets, message):
            await self.disconnect(ws, tenant_id)

    async def broadcast_all(self, event: dict[str, Any]) -> None:
        """Send event to ALL connected WebSockets regardless of tenant."""
        message = json.dumps(event)
        async with self._lock:
            snapshot = {tid: set(sockets) for tid, sockets in self._connections.items()}
        # Sends happen outside the lock so a slow client cannot block
        # connect()/disconnect() for everyone else.
        for tenant_id, sockets in snapshot.items():
            for ws in await self._send_to(sockets, message):
                await self.disconnect(ws, tenant_id)

    # ── Redis Pub/Sub subscriber ──────────────────────────────────────────────

    async def start_redis_subscriber(self) -> None:
        """Start background task that listens for Redis Pub/Sub messages.

        No-op while a subscriber task is still running; a task that has
        already finished is replaced by a fresh one.
        """
        if self._subscriber_task is not None and not self._subscriber_task.done():
            return
        self._subscriber_task = asyncio.create_task(self._subscribe_loop())
        logger.info("WebSocket Redis subscriber started")

    async def _handle_message(self, message: dict[str, Any]) -> None:
        """Route one Pub/Sub message to the matching broadcast method.

        Subscription confirmations, payloads that are not valid JSON and
        messages on unknown channels are ignored.
        """
        if message.get("type") not in ("message", "pmessage"):
            return
        channel: str = message.get("channel") or message.get("pattern") or ""
        try:
            event = json.loads(message.get("data", ""))
        except (json.JSONDecodeError, TypeError):
            return

        if channel == "__broadcast__":
            await self.broadcast_all(event)
        elif channel.startswith("tenant:"):
            await self.broadcast_to_tenant(channel[len("tenant:"):], event)

    @staticmethod
    async def _close_quietly(resource: Any) -> None:
        """Close a Redis client / PubSub object, ignoring any error.

        Tries ``aclose()`` first and falls back to ``close()`` so the helper
        works whichever of the two the installed redis-py exposes.
        """
        if resource is None:
            return
        closer = getattr(resource, "aclose", None) or getattr(resource, "close", None)
        if closer is None:
            return
        try:
            await closer()
        except Exception as exc:
            logger.debug("Ignoring error while closing Redis resource: %s", exc)

    async def _subscribe_loop(self) -> None:
        """Consume Redis Pub/Sub messages until cancelled.

        On any error the connection is torn down and re-established after a
        3 second pause. The client and PubSub object of each attempt are
        always closed, so reconnects do not leak connections.
        """
        from app.config import settings
        import redis.asyncio as aioredis

        while True:
            client = None
            pubsub = None
            try:
                try:
                    client = aioredis.from_url(settings.redis_url, decode_responses=True)
                    pubsub = client.pubsub()
                    await pubsub.psubscribe("tenant:*", "__broadcast__")
                    logger.info("Subscribed to Redis channels tenant:* and __broadcast__")
                    async for message in pubsub.listen():
                        await self._handle_message(message)
                finally:
                    # Release this attempt's resources before sleeping/retrying.
                    await self._close_quietly(pubsub)
                    await self._close_quietly(client)
            except asyncio.CancelledError:
                break
            except Exception as exc:
                logger.warning("Redis subscriber error, reconnecting in 3s: %s", exc)
                await asyncio.sleep(3)

    async def stop(self) -> None:
        """Cancel the subscriber task (if any) and wait for it to finish."""
        if self._subscriber_task:
            self._subscriber_task.cancel()
            try:
                await self._subscriber_task
            except asyncio.CancelledError:
                pass
            self._subscriber_task = None
|
||||
|
||||
|
||||
# ── Singleton instance ────────────────────────────────────────────────────────
|
||||
|
||||
manager = ConnectionManager()
|
||||
|
||||
|
||||
# ── Sync helper for Celery tasks ──────────────────────────────────────────────
|
||||
|
||||
def publish_event_sync(tenant_id: str, event: dict[str, Any]) -> None:
    """Publish a WebSocket event from a synchronous Celery task.

    Uses a plain (sync) Redis client to publish to the Pub/Sub channel.
    The async subscriber loop in the FastAPI process will forward it to WS clients.

    Args:
        tenant_id: Tenant whose channel (``tenant:<id>``) receives the event.
            The literal ``"__broadcast__"`` publishes to the global channel.
        event: JSON-serialisable payload.

    Never raises: publishing is best-effort, failures are logged as warnings.
    """
    try:
        import redis as sync_redis
        from app.config import settings

        channel = "__broadcast__" if tenant_id == "__broadcast__" else f"tenant:{tenant_id}"
        # Serialise before connecting so a bad payload costs no connection.
        payload = json.dumps(event)

        client = sync_redis.from_url(settings.redis_url, decode_responses=True)
        try:
            client.publish(channel, payload)
        finally:
            # Always release the connection, even when publish() raises.
            client.close()
    except Exception as exc:
        logger.warning("publish_event_sync failed: %s", exc)
|
||||
+53
-1
@@ -1,11 +1,13 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
import uuid
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pathlib import Path
|
||||
|
||||
from app.config import settings
|
||||
from app.database import engine, Base
|
||||
from app.core.websocket import manager as ws_manager
|
||||
|
||||
# Import routers from domain locations
|
||||
from app.domains.auth.router import router as auth_router
|
||||
@@ -27,7 +29,10 @@ async def lifespan(app: FastAPI):
|
||||
# Create upload directories
|
||||
for subdir in ("step_files", "excel_files", "thumbnails", "renders", "blend-templates"):
|
||||
Path(settings.upload_dir, subdir).mkdir(parents=True, exist_ok=True)
|
||||
# Start WebSocket Redis subscriber
|
||||
await ws_manager.start_redis_subscriber()
|
||||
yield
|
||||
await ws_manager.stop()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
@@ -86,3 +91,50 @@ app.include_router(media_router)
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "service": "schaefflerautomat-backend"}
|
||||
|
||||
|
||||
@app.websocket("/api/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    token: str = Query(..., description="JWT access token"),
):
    """WebSocket endpoint for real-time events.

    Clients connect with ?token=<jwt>. Events are scoped by tenant_id; a user
    without a tenant is registered under their own user id instead.

    The socket is closed with code 4001 (before being accepted) when the token
    is rejected, its ``sub`` claim is missing or not a valid UUID, or the user
    does not exist or is inactive.
    """
    import logging

    from app.utils.auth import decode_token
    from app.database import AsyncSessionLocal
    from sqlalchemy import select
    from app.models.user import User

    # Authenticate via token query param (WS cannot send Authorization header)
    try:
        payload = decode_token(token)
        user_id = payload.get("sub")
        if not user_id:
            await websocket.close(code=4001)
            return
        # A malformed subject must reject the client, not crash the handler.
        user_uuid = uuid.UUID(str(user_id))
    except (HTTPException, ValueError):
        await websocket.close(code=4001)
        return

    # Load user to get tenant_id
    async with AsyncSessionLocal() as db:
        result = await db.execute(select(User).where(User.id == user_uuid))
        user = result.scalar_one_or_none()

    if not user or not user.is_active:
        await websocket.close(code=4001)
        return

    tenant_id = str(user.tenant_id) if user.tenant_id else str(user_id)

    await ws_manager.connect(websocket, tenant_id)
    try:
        while True:
            # Keep alive — clients send periodic pings as text
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass  # normal client-side close
    except Exception:
        logging.getLogger(__name__).debug(
            "WebSocket receive loop ended unexpectedly", exc_info=True
        )
    finally:
        # Runs on every exit path (including task cancellation) so the
        # manager never keeps a reference to a dead socket.
        await ws_manager.disconnect(websocket, tenant_id)
|
||||
|
||||
@@ -504,7 +504,7 @@ def render_turntable_to_file(
|
||||
"-framerate", str(fps),
|
||||
"-i", str(frames_dir / "frame_%04d.png"),
|
||||
"-f", "lavfi", "-i", f"color=c=0x{hex_color}:size={width}x{height}:rate={fps}",
|
||||
"-filter_complex", "[1:v][0:v]overlay=0:0",
|
||||
"-filter_complex", "[1:v][0:v]overlay=0:0:shortest=1",
|
||||
"-vcodec", "libx264",
|
||||
"-pix_fmt", "yuv420p",
|
||||
"-crf", "18",
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Celery Beat periodic tasks."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from celery import shared_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(name="app.tasks.beat_tasks.broadcast_queue_status", queue="step_processing")
def broadcast_queue_status() -> None:
    """Broadcast current queue depths to all WebSocket clients every 10s.

    Reads the length of the ``step_processing`` and ``thumbnail_rendering``
    Redis lists and publishes them as a ``queue_update`` event to the Redis
    '__broadcast__' channel, which the WebSocket subscriber in the FastAPI
    process forwards to all connected clients.

    Never raises: failures are logged as warnings so a Redis hiccup does not
    mark the periodic task as failed.
    """
    try:
        import redis as sync_redis
        from app.config import settings

        client = sync_redis.from_url(settings.redis_url, decode_responses=True)
        try:
            depths = {
                queue_name: client.llen(queue_name)
                for queue_name in ("step_processing", "thumbnail_rendering")
            }
            event = {"type": "queue_update", "depths": depths}
            client.publish("__broadcast__", json.dumps(event))
        finally:
            # Always release the connection, even when llen()/publish() raises.
            client.close()
        logger.debug("Broadcast queue_update: %s", depths)
    except Exception as exc:
        logger.warning("broadcast_queue_status failed: %s", exc)
|
||||
@@ -1,4 +1,5 @@
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
from app.config import settings
|
||||
|
||||
celery_app = Celery(
|
||||
@@ -8,6 +9,7 @@ celery_app = Celery(
|
||||
include=[
|
||||
"app.tasks.step_tasks",
|
||||
"app.tasks.ai_tasks",
|
||||
"app.tasks.beat_tasks",
|
||||
"app.domains.rendering.tasks",
|
||||
"app.domains.products.tasks",
|
||||
"app.domains.imports.tasks",
|
||||
@@ -23,7 +25,13 @@ celery_app.conf.update(
|
||||
task_routes={
|
||||
"app.tasks.step_tasks.*": {"queue": "step_processing"},
|
||||
"app.tasks.ai_tasks.*": {"queue": "ai_validation"},
|
||||
"app.tasks.beat_tasks.*": {"queue": "step_processing"},
|
||||
"app.domains.rendering.tasks.*": {"queue": "thumbnail_rendering"},
|
||||
},
|
||||
beat_schedule={},
|
||||
beat_schedule={
|
||||
"broadcast-queue-status-every-10s": {
|
||||
"task": "app.tasks.beat_tasks.broadcast_queue_status",
|
||||
"schedule": 10.0, # every 10 seconds
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -171,6 +171,28 @@ def render_step_thumbnail(self, cad_file_id: str):
|
||||
f"Auto material population failed for cad_file {cad_file_id} (non-fatal)"
|
||||
)
|
||||
|
||||
# Broadcast WebSocket event for live UI updates
|
||||
try:
|
||||
from sqlalchemy import create_engine, select as sql_select2
|
||||
from sqlalchemy.orm import Session as _Session
|
||||
from app.config import settings as _cfg
|
||||
from app.models.cad_file import CadFile as _CadFile
|
||||
_sync_url = _cfg.database_url.replace("+asyncpg", "")
|
||||
_eng = create_engine(_sync_url)
|
||||
with _Session(_eng) as _s:
|
||||
_cad = _s.get(_CadFile, cad_file_id)
|
||||
_tid = str(_cad.tenant_id) if _cad and _cad.tenant_id else None
|
||||
_eng.dispose()
|
||||
if _tid:
|
||||
from app.core.websocket import publish_event_sync
|
||||
publish_event_sync(_tid, {
|
||||
"type": "cad_processing_complete",
|
||||
"cad_file_id": cad_file_id,
|
||||
"status": "completed",
|
||||
})
|
||||
except Exception:
|
||||
logger.debug("WebSocket publish for CAD complete skipped (non-fatal)")
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.step_tasks.generate_stl_cache", queue="thumbnail_rendering")
|
||||
def generate_stl_cache(self, cad_file_id: str, quality: str):
|
||||
@@ -559,6 +581,22 @@ def render_order_line_task(self, order_line_id: str):
|
||||
else:
|
||||
emit(order_line_id, f"Render failed after {elapsed:.1f}s", "error")
|
||||
|
||||
# Broadcast WebSocket event for live UI updates
|
||||
try:
|
||||
from app.core.websocket import publish_event_sync
|
||||
_tenant_id = str(line.product.cad_file.tenant_id) if (
|
||||
line.product and line.product.cad_file and line.product.cad_file.tenant_id
|
||||
) else None
|
||||
if _tenant_id:
|
||||
publish_event_sync(_tenant_id, {
|
||||
"type": "render_complete" if success else "render_failed",
|
||||
"order_line_id": order_line_id,
|
||||
"order_id": str(line.order_id),
|
||||
"status": new_status,
|
||||
})
|
||||
except Exception:
|
||||
logger.debug("WebSocket publish skipped (non-fatal)")
|
||||
|
||||
# Notify order creator about render result
|
||||
try:
|
||||
from app.models.order import Order as OrderModel
|
||||
|
||||
Reference in New Issue
Block a user