feat(refactor/phase1): foundation infrastructure for modular pipeline

Phase 1 of PLAN_REFACTOR.md — all four sub-tasks implemented:

1.1 PipelineLogger (backend/app/core/pipeline_logger.py)
  - Structured step_start/step_done/step_error/step_progress API
  - Publishes to Python logging AND Redis SSE via log_task_event
  - Context manager `pl.step("name")` for auto-timing

1.2 RenderJobDocument (backend/app/domains/rendering/job_document.py)
  - Pydantic JSONB schema: state machine + per-step records + timing
  - begin_step/finish_step/fail_step/skip_step helpers
  - Migration 048: adds render_job_doc JSONB column to order_lines
  - OrderLine model updated with render_job_doc field

1.3 TenantContextMiddleware (backend/app/core/middleware.py)
  - Decodes JWT, stores tenant_id + role in request.state
  - get_db updated to auto-apply RLS SET LOCAL from request.state
  - Registered in main.py (runs before every request)
  - JWT now embeds tenant_id claim via create_access_token()
  - Login endpoint passes tenant_id to token creation

1.4 ProcessStep Registry (backend/app/core/process_steps.py)
  - StepName StrEnum with all 19 pipeline step names
  - Single source of truth for log prefixes, DB records, UI labels

Also adds db_utils.py with set_tenant_sync() + get_sync_session()
for use inside Celery tasks (bypass-safe RLS helper).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-08 19:25:08 +01:00
parent ee6eb34b4c
commit ea31ed657c
12 changed files with 1654 additions and 5 deletions
+1174
View File
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,27 @@
"""Add render_job_doc JSONB column to order_lines.
Stores a structured RenderJobDocument (state machine, per-step results,
timing, GPU info) alongside the legacy render_log column.
Revision ID: 048
Revises: 047
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB
# Alembic revision identifiers: this migration (048) follows 047 directly.
revision = "048"
down_revision = "047"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the nullable ``render_job_doc`` JSONB column to ``order_lines``."""
    doc_column = sa.Column("render_job_doc", JSONB, nullable=True)
    op.add_column("order_lines", doc_column)
def downgrade() -> None:
    """Remove the ``render_job_doc`` column (any stored job documents are lost)."""
    op.drop_column(
        "order_lines",
        "render_job_doc",
    )
+2 -1
View File
@@ -38,7 +38,8 @@ async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
if not user.is_active:
raise HTTPException(status_code=403, detail="Account disabled")
token = create_access_token(str(user.id), user.role.value)
tenant_id = str(user.tenant_id) if user.tenant_id else None
token = create_access_token(str(user.id), user.role.value, tenant_id=tenant_id)
return TokenResponse(access_token=token, user=UserOut.model_validate(user))
+71
View File
@@ -0,0 +1,71 @@
"""Database utilities for use inside Celery tasks (sync context).
Celery tasks bypass FastAPI middleware, so tenant RLS context must be set
manually. Use set_tenant_sync() immediately after creating a sync DB session.
Usage::
from app.core.db_utils import set_tenant_sync, get_sync_session
with get_sync_session() as db:
set_tenant_sync(db, tenant_id, role)
# queries here will be RLS-filtered
"""
import contextlib
from typing import Generator
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, sessionmaker
from app.config import settings
def set_tenant_sync(session: Session, tenant_id: str | None, role: str | None = None) -> None:
    """Set RLS tenant context on a *synchronous* SQLAlchemy session.

    Call this at the very start of any sync DB block inside a Celery task
    when you need tenant isolation. Admins bypass RLS; all other roles get
    a scoped context. If tenant_id is None the call is a no-op (global
    access, i.e. no RLS enforcement).

    Args:
        session: Open synchronous session; must be inside a transaction so
            the transaction-local setting scopes correctly.
        tenant_id: Tenant id as a string, or None to skip RLS entirely.
        role: User role; "admin" receives the RLS bypass sentinel.
    """
    if not tenant_id:
        return  # global access — nothing to set
    # Use set_config() rather than `SET LOCAL ... = :tid`: SET statements do
    # not accept bind parameters under server-side parameter binding, while
    # set_config() is an ordinary function call and parameterizes safely.
    # The third argument (is_local=true) gives SET LOCAL semantics: the
    # setting resets at transaction end.
    value = "bypass" if role == "admin" else tenant_id
    session.execute(
        text("SELECT set_config('app.current_tenant_id', :val, true)"),
        {"val": value},
    )
# Lazily created sync engine (reused across tasks in the same worker process)
_sync_engine = None


def _get_sync_engine():
    """Return the process-wide synchronous engine, creating it on first use.

    Async driver suffixes are stripped from the configured URL so the same
    DSN works with the default synchronous driver.
    """
    global _sync_engine
    if _sync_engine is not None:
        return _sync_engine
    url = settings.database_url
    for async_suffix in ("+asyncpg", "+aiosqlite"):
        url = url.replace(async_suffix, "")
    _sync_engine = create_engine(url, pool_pre_ping=True, pool_size=5, max_overflow=10)
    return _sync_engine
@contextlib.contextmanager
def get_sync_session(tenant_id: str | None = None, role: str | None = None) -> Generator[Session, None, None]:
    """Yield a synchronous DB session, optionally scoped to a tenant via RLS.

    Commits on clean exit of the ``with`` block; rolls back and re-raises on
    any exception. Prefer the existing async session patterns in FastAPI
    routes — this helper is intended for Celery tasks only.
    """
    session_factory = sessionmaker(bind=_get_sync_engine(), expire_on_commit=False)
    with session_factory() as session:
        if tenant_id:
            set_tenant_sync(session, tenant_id, role)
        try:
            yield session
        except Exception:
            session.rollback()
            raise
        # Reached only when the caller's block completed without raising.
        session.commit()
+49
View File
@@ -0,0 +1,49 @@
"""Application middleware.
TenantContextMiddleware
Decodes the JWT Bearer token (if present) from every incoming request and
stores tenant_id + role in request.state. The get_db dependency reads
request.state to automatically set the RLS context before yielding the
session — no endpoint code change required.
"""
import logging
from jose import JWTError, jwt
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from app.config import settings
_log = logging.getLogger(__name__)
class TenantContextMiddleware(BaseHTTPMiddleware):
    """Extract JWT → inject tenant_id + role into request.state.

    Does NOT reject unauthenticated requests — that is still handled by the
    route-level dependencies (require_admin, get_current_user, etc.).
    Missing / invalid tokens result in request.state.tenant_id = None.
    """

    @staticmethod
    def _claims_from_header(auth_header: str) -> tuple[str | None, str | None]:
        """Return (tenant_id, role) from a Bearer header, or (None, None)."""
        if not auth_header.startswith("Bearer "):
            return None, None
        try:
            payload = jwt.decode(
                auth_header[7:],
                settings.jwt_secret_key,
                algorithms=[settings.jwt_algorithm],
            )
        except JWTError:
            # invalid/expired tokens are handled per-endpoint
            return None, None
        return payload.get("tenant_id"), payload.get("role")

    async def dispatch(self, request: Request, call_next) -> Response:
        header = request.headers.get("Authorization", "")
        tenant_id, role = self._claims_from_header(header)
        request.state.tenant_id = tenant_id
        request.state.role = role
        return await call_next(request)
+106
View File
@@ -0,0 +1,106 @@
"""Structured pipeline logger.
Wraps Python logging + Redis SSE streaming for consistent, prefixed log output
from all Celery pipeline tasks. Every method:
- emits a Python `logging` line with a [STEP_NAME] prefix
- publishes to Redis via log_task_event for SSE streaming in the UI
"""
import logging
import time
from typing import Any
from app.core.task_logs import log_task_event
_log = logging.getLogger(__name__)
class PipelineLogger:
    """Structured logger for a single pipeline execution context.

    Every method emits a ``[step]``-prefixed line via Python logging AND
    publishes it with ``log_task_event`` for SSE streaming in the UI.

    Usage in a Celery task::

        pl = PipelineLogger(task_id=self.request.id, order_line_id=str(line.id))
        pl.step_start("occ_glb_export", {"cad_file_id": cad_file_id})
        ...
        pl.step_done("occ_glb_export", duration_s=8.4, result={"size_bytes": 204800})
    """

    def __init__(self, task_id: str | None, order_line_id: str | None = None):
        # "unknown" fallback keeps log_task_event routing valid when no task
        # id is available (e.g. eager/test execution).
        self.task_id = task_id or "unknown"
        self.order_line_id = order_line_id
        # step name -> monotonic start timestamp, for auto-computed durations
        self._step_starts: dict[str, float] = {}

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def step_start(self, step: str, context: dict[str, Any] | None = None) -> None:
        """Record the step's start time and emit a "start" line."""
        # time.monotonic() instead of time.time(): wall-clock adjustments
        # (NTP, manual changes) would otherwise corrupt computed durations.
        self._step_starts[step] = time.monotonic()
        msg = f"[{step}] start"
        if context:
            msg += f" | {context}"
        _log.info(msg)
        log_task_event(self.task_id, msg, level="info")

    def step_progress(self, step: str, pct: int, msg: str) -> None:
        """Emit an intermediate progress line (pct expected in 0-100)."""
        full = f"[{step}] {pct}% — {msg}"
        _log.info(full)
        log_task_event(self.task_id, full, level="info")

    def step_done(self, step: str, duration_s: float | None = None, result: dict[str, Any] | None = None) -> None:
        """Emit a "done" line; duration is derived from step_start if omitted."""
        # pop (not get): a re-run of the same step must never reuse a stale
        # start time, and the dict must not grow across many steps.
        start = self._step_starts.pop(step, None)
        if duration_s is None and start is not None:
            duration_s = round(time.monotonic() - start, 2)
        parts = [f"[{step}] done"]
        if duration_s is not None:
            parts.append(f"{duration_s:.1f}s")
        if result:
            parts.append(str(result))
        msg = " | ".join(parts)
        _log.info(msg)
        log_task_event(self.task_id, msg, level="info")

    def step_error(self, step: str, error: str, exc: Exception | None = None) -> None:
        """Emit an ERROR line; includes the traceback when an exception is given."""
        msg = f"[{step}] ERROR — {error}"
        if exc:
            _log.exception(msg)
        else:
            _log.error(msg)
        log_task_event(self.task_id, msg, level="error")

    def info(self, step: str, msg: str) -> None:
        """Emit a plain info line under the step's prefix."""
        full = f"[{step}] {msg}"
        _log.info(full)
        log_task_event(self.task_id, full, level="info")

    def warning(self, step: str, msg: str) -> None:
        """Emit a WARNING line under the step's prefix."""
        full = f"[{step}] WARNING — {msg}"
        _log.warning(full)
        log_task_event(self.task_id, full, level="warning")

    # ------------------------------------------------------------------
    # Context manager for a single step
    # ------------------------------------------------------------------
    def step(self, step_name: str, context: dict[str, Any] | None = None) -> "_StepContext":
        """Return a context manager that auto-logs start/done/error for the step."""
        return _StepContext(self, step_name, context)
class _StepContext:
    """Context manager that auto-calls step_start / step_done / step_error."""

    def __init__(self, pl: PipelineLogger, step_name: str, context: dict | None):
        self._logger = pl
        self._step = step_name
        self._ctx = context

    def __enter__(self) -> "_StepContext":
        self._logger.step_start(self._step, self._ctx)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Error path reports the exception; success path closes out the step.
        if exc_type is not None:
            self._logger.step_error(self._step, str(exc_val), exc_val)
        else:
            self._logger.step_done(self._step)
        # Never suppress — the task's own error handling must still see it.
        return False
+39
View File
@@ -0,0 +1,39 @@
"""Named pipeline step identifiers.
All Celery tasks and render scripts reference these constants so that log
messages, DB records, and UI labels stay consistent across the codebase.
"""
from enum import StrEnum
class StepName(StrEnum):
    """Canonical identifier for every pipeline step.

    StrEnum members compare equal to their string values, so these can be
    used directly as log prefixes, DB record keys, and UI labels.
    """

    # ── STEP file processing ──────────────────────────────────────────
    RESOLVE_STEP_PATH = "resolve_step_path"
    OCC_OBJECT_EXTRACT = "occ_object_extract"
    OCC_GLB_EXPORT = "occ_glb_export"
    GLB_BBOX = "glb_bbox"
    MATERIAL_MAP_RESOLVE = "material_map_resolve"
    AUTO_POPULATE_MATERIALS = "auto_populate_materials"

    # ── Thumbnail generation ─────────────────────────────────────────
    BLENDER_RENDER = "blender_render"
    THREEJS_RENDER = "threejs_render"
    THUMBNAIL_SAVE = "thumbnail_save"

    # ── Order line render ─────────────────────────────────────────────
    ORDER_LINE_SETUP = "order_line_setup"
    RESOLVE_TEMPLATE = "resolve_template"
    BLENDER_STILL = "blender_still"
    BLENDER_TURNTABLE = "blender_turntable"
    OUTPUT_SAVE = "output_save"

    # ── GLB / asset export ────────────────────────────────────────────
    EXPORT_GLB_GEOMETRY = "export_glb_geometry"
    EXPORT_GLB_PRODUCTION = "export_glb_production"
    EXPORT_BLEND = "export_blend"

    # ── STL cache ────────────────────────────────────────────────────
    STL_CACHE_GENERATE = "stl_cache_generate"

    # ── Notifications ─────────────────────────────────────────────────
    NOTIFY = "notify"
+18 -2
View File
@@ -1,9 +1,13 @@
from typing import AsyncGenerator, Optional
from __future__ import annotations
from typing import TYPE_CHECKING, AsyncGenerator, Optional
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy import text
from app.config import settings
if TYPE_CHECKING:
from starlette.requests import Request
engine = create_async_engine(
settings.database_url,
echo=False,
@@ -23,8 +27,20 @@ class Base(DeclarativeBase):
pass
async def get_db() -> AsyncGenerator[AsyncSession, None]:
async def get_db(request: "Request | None" = None) -> AsyncGenerator[AsyncSession, None]:
async with AsyncSessionLocal() as session:
# Auto-apply RLS context if TenantContextMiddleware populated request.state
if request is not None:
tenant_id = getattr(request.state, "tenant_id", None)
role = getattr(request.state, "role", None)
if tenant_id:
if role == "admin":
await session.execute(text("SET LOCAL app.current_tenant_id = 'bypass'"))
else:
await session.execute(
text("SET LOCAL app.current_tenant_id = :tid"),
{"tid": tenant_id},
)
try:
yield session
finally:
+1
View File
@@ -131,6 +131,7 @@ class OrderLine(Base):
render_status: Mapped[str] = mapped_column(String(20), nullable=False, default="pending")
result_path: Mapped[str | None] = mapped_column(String(1000), nullable=True)
render_log: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
render_job_doc: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
ai_validation_status: Mapped[str] = mapped_column(String(20), nullable=False, default="not_started")
ai_validation_result: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
flamenco_job_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
@@ -0,0 +1,161 @@
"""RenderJobDocument — structured JSONB job ticket stored in order_lines.render_job_doc.
Acts as the single source of truth for a render job's state machine.
Stored as JSONB in order_lines.render_job_doc; keep order_lines.render_log
for backward compat (deprecated, removed in Phase 3).
Usage::
from app.domains.rendering.job_document import RenderJobDocument, JobState, StepRecord
doc = RenderJobDocument.new(order_line_id=str(line.id), celery_task_id=self.request.id)
doc.begin_step("occ_glb_export")
...
doc.finish_step("occ_glb_export", output={"glb_path": str(glb), "size_bytes": sz})
doc.set_state(JobState.COMPLETED, result={"output_path": str(out)})
# Persist to DB (inside Celery sync task):
line.render_job_doc = doc.to_dict()
db.commit()
"""
import time
from datetime import datetime, timezone
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
def _now_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()
# ── State machine ─────────────────────────────────────────────────────────────
class JobState(StrEnum):
    """Top-level lifecycle states of a render job (the job state machine)."""

    PENDING = "pending"
    QUEUED = "queued"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class StepStatus(StrEnum):
    """Status of one individual pipeline step within a render job."""

    PENDING = "pending"
    RUNNING = "running"
    DONE = "done"
    FAILED = "failed"
    SKIPPED = "skipped"
# ── Data models ───────────────────────────────────────────────────────────────
class StepRecord(BaseModel):
    """Result record for one pipeline step, embedded in RenderJobDocument.steps."""

    name: str  # step identifier (matches StepName values per the commit plan)
    status: StepStatus = StepStatus.PENDING
    started_at: str | None = None  # ISO-8601 UTC timestamp
    completed_at: str | None = None  # ISO-8601 UTC timestamp
    duration_s: float | None = None  # seconds, rounded to 2 decimals
    output: dict[str, Any] | None = None  # step-specific result payload
    error: str | None = None  # set when status is FAILED
class RenderJobDocument(BaseModel):
    """Structured job ticket persisted as JSONB in order_lines.render_job_doc.

    Single source of truth for a render job: overall state machine plus one
    StepRecord per pipeline step. Timestamps are kept as ISO-8601 strings so
    the document round-trips cleanly through JSON.
    """

    version: int = 1
    job_id: str  # == order_line_id
    created_at: str = Field(default_factory=_now_iso)
    updated_at: str = Field(default_factory=_now_iso)
    state: JobState = JobState.PENDING
    celery_task_id: str | None = None
    steps: list[StepRecord] = Field(default_factory=list)
    error: str | None = None
    result: dict[str, Any] | None = None

    # ── Factory ──────────────────────────────────────────────────────
    @classmethod
    def new(cls, order_line_id: str, celery_task_id: str | None = None) -> "RenderJobDocument":
        """Create a fresh document for the given order line."""
        return cls(job_id=order_line_id, celery_task_id=celery_task_id)

    @classmethod
    def from_dict(cls, d: dict | None) -> "RenderJobDocument | None":
        """Best-effort deserialisation: None for empty or invalid input."""
        if not d:
            return None
        try:
            return cls.model_validate(d)
        except Exception:
            return None

    # ── Mutation helpers ─────────────────────────────────────────────
    def set_state(self, state: JobState, result: dict[str, Any] | None = None, error: str | None = None) -> None:
        """Transition the job state, optionally attaching a result/error payload."""
        self.state = state
        if result is not None:
            self.result = result
        if error is not None:
            self.error = error
        self.updated_at = _now_iso()

    def begin_step(self, step_name: str) -> StepRecord:
        """Mark a step as running. Creates it if not present."""
        record = self._get_or_create_step(step_name)
        record.status = StepStatus.RUNNING
        record.started_at = _now_iso()
        self.updated_at = _now_iso()
        # The first running step promotes the whole job out of pending/queued.
        if self.state in (JobState.PENDING, JobState.QUEUED):
            self.state = JobState.RUNNING
        return record

    def finish_step(
        self,
        step_name: str,
        output: dict[str, Any] | None = None,
        duration_s: float | None = None,
    ) -> StepRecord:
        """Mark a step done; derives duration from started_at when not supplied."""
        record = self._get_or_create_step(step_name)
        record.status = StepStatus.DONE
        record.completed_at = _now_iso()
        if duration_s is not None:
            record.duration_s = round(duration_s, 2)
        elif record.started_at:
            try:
                started = datetime.fromisoformat(record.started_at)
                elapsed = (datetime.now(timezone.utc) - started).total_seconds()
                record.duration_s = round(elapsed, 2)
            except Exception:
                pass  # malformed timestamp — leave duration unset
        if output is not None:
            record.output = output
        self.updated_at = _now_iso()
        return record

    def fail_step(self, step_name: str, error: str) -> StepRecord:
        """Mark a step failed and record its error message."""
        record = self._get_or_create_step(step_name)
        record.status = StepStatus.FAILED
        record.completed_at = _now_iso()
        record.error = error
        self.updated_at = _now_iso()
        return record

    def skip_step(self, step_name: str, reason: str | None = None) -> StepRecord:
        """Mark a step skipped, optionally noting why in its output."""
        record = self._get_or_create_step(step_name)
        record.status = StepStatus.SKIPPED
        if reason:
            record.output = {"reason": reason}
        self.updated_at = _now_iso()
        return record

    # ── Serialisation ────────────────────────────────────────────────
    def to_dict(self) -> dict:
        """Return a JSON-safe dict suitable for assignment to the JSONB column."""
        return self.model_dump(mode="json")

    # ── Internal ─────────────────────────────────────────────────────
    def _get_or_create_step(self, step_name: str) -> StepRecord:
        """Find the named step record, appending a fresh pending one if missing."""
        existing = next((s for s in self.steps if s.name == step_name), None)
        if existing is not None:
            return existing
        created = StepRecord(name=step_name)
        self.steps.append(created)
        return created
+2
View File
@@ -8,6 +8,7 @@ from pathlib import Path
from app.config import settings
from app.database import engine, Base
from app.core.websocket import manager as ws_manager
from app.core.middleware import TenantContextMiddleware
# Import routers from domain locations
from app.domains.auth.router import router as auth_router
@@ -52,6 +53,7 @@ app.add_middleware(
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(TenantContextMiddleware)
# Mount static files for thumbnails (dir created in lifespan; skip if not writable)
thumbnails_dir = Path(settings.upload_dir) / "thumbnails"
+4 -2
View File
@@ -28,9 +28,11 @@ def verify_password(plain: str, hashed: str) -> bool:
return pwd_context.verify(plain, hashed)
def create_access_token(user_id: str, role: str) -> str:
def create_access_token(user_id: str, role: str, tenant_id: str | None = None) -> str:
    """Create a signed JWT access token.

    Args:
        user_id: Subject claim ("sub").
        role: Role claim embedded in the token.
        tenant_id: Optional tenant claim; omitted entirely when falsy so
            global users and pre-existing tokens are unaffected.

    Returns:
        The encoded JWT string.
    """
    # Local import: the file's visible header only guarantees datetime/timedelta.
    from datetime import timezone

    # Aware UTC datetime instead of datetime.utcnow(): utcnow() is naive and
    # deprecated since Python 3.12, and naive "exp" values invite comparison bugs.
    expires = datetime.now(timezone.utc) + timedelta(minutes=settings.jwt_access_token_expire_minutes)
    payload: dict = {"sub": user_id, "role": role, "exp": expires}
    if tenant_id:
        payload["tenant_id"] = tenant_id
    return jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)