diff --git a/backend/alembic/versions/035_tenants.py b/backend/alembic/versions/035_tenants.py new file mode 100644 index 0000000..f814a7c --- /dev/null +++ b/backend/alembic/versions/035_tenants.py @@ -0,0 +1,36 @@ +"""Add tenants table. + +Revision ID: 035 +Revises: 034 +Create Date: 2026-03-06 +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import UUID + +revision = '035' +down_revision = '034' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + 'tenants', + sa.Column('id', UUID(as_uuid=True), primary_key=True, + server_default=sa.text('gen_random_uuid()')), + sa.Column('name', sa.String(200), nullable=False), + sa.Column('slug', sa.String(100), nullable=False, unique=True), + sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), + sa.Column('created_at', sa.DateTime(), nullable=False, + server_default=sa.text('NOW()')), + ) + # Seed default tenant — all existing data will be assigned to this tenant + op.execute(""" + INSERT INTO tenants (name, slug, is_active) + VALUES ('Schaeffler', 'schaeffler', true) + """) + + +def downgrade(): + op.drop_table('tenants') diff --git a/backend/alembic/versions/036_tenant_rls.py b/backend/alembic/versions/036_tenant_rls.py new file mode 100644 index 0000000..7f5373e --- /dev/null +++ b/backend/alembic/versions/036_tenant_rls.py @@ -0,0 +1,125 @@ +"""Add tenant_id FK to all tables + enable Row Level Security. + +Revision ID: 036 +Revises: 035 +Create Date: 2026-03-06 +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import UUID + +revision = '036' +down_revision = '035' +branch_labels = None +depends_on = None + +# Tables that receive tenant_id + RLS. +# product_variants was removed in migration 027, so we check existence before acting. +TENANT_TABLES = [ + "users", + "orders", + "order_items", + "order_lines", + "products", + "cad_files", + "materials", + "material_aliases", + "render_templates", + "output_types", + "pricing_tiers", + "audit_log", + "templates", + "product_variants", # dropped in 027 — handled with existence check +] + + +def _table_exists(table_name: str) -> bool: + """Check if a table exists in the public schema.""" + conn = op.get_bind() + result = conn.execute( + sa.text( + "SELECT 1 FROM information_schema.tables " + "WHERE table_schema = 'public' AND table_name = :t" + ), + {"t": table_name}, + ) + return result.fetchone() is not None + + +def upgrade(): + # Grant BYPASSRLS to the DB user if possible (superuser op — ignore if insufficient privilege) + op.execute(""" + DO $$ + BEGIN + EXECUTE 'ALTER ROLE ' || current_user || ' BYPASSRLS'; + EXCEPTION WHEN insufficient_privilege THEN + RAISE NOTICE 'Could not set BYPASSRLS — run manually as superuser if needed'; + END; + $$; + """) + + for table in TENANT_TABLES: + if not _table_exists(table): + continue + + # 1. Add nullable tenant_id column + op.add_column( + table, + sa.Column( + "tenant_id", + UUID(as_uuid=True), + sa.ForeignKey("tenants.id"), + nullable=True, + index=True, + ), + ) + + # 2. Backfill with the default 'schaeffler' tenant + op.execute( + f"UPDATE {table} " + "SET tenant_id = (SELECT id FROM tenants WHERE slug = 'schaeffler')" + ) + + # 3. Make NOT NULL now that every row has a value + op.alter_column(table, "tenant_id", nullable=False) + + # 4. Enable Row Level Security + op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY") + + # 5. Main isolation policy: tenant_id must match current session tenant + op.execute(f""" + CREATE POLICY tenant_isolation ON {table} + USING ( + tenant_id = current_setting('app.current_tenant_id', true)::uuid + ) + """) + + # 6. Admin bypass policy: allows queries when setting is 'bypass' + op.execute(f""" + CREATE POLICY admin_bypass ON {table} + USING ( + current_setting('app.current_tenant_id', true) = 'bypass' + ) + """) + + +def downgrade(): + # Grant BYPASSRLS so the downgrade itself can see all rows + op.execute(""" + DO $$ + BEGIN + EXECUTE 'ALTER ROLE ' || current_user || ' BYPASSRLS'; + EXCEPTION WHEN insufficient_privilege THEN + RAISE NOTICE 'Could not set BYPASSRLS'; + END; + $$; + """) + + for table in reversed(TENANT_TABLES): + if not _table_exists(table): + continue + + op.execute(f"DROP POLICY IF EXISTS admin_bypass ON {table}") + op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}") + op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY") + op.drop_column(table, "tenant_id") diff --git a/backend/app/database.py b/backend/app/database.py index 9d4f08f..207ac00 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -1,5 +1,7 @@ +from typing import AsyncGenerator, Optional from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from sqlalchemy.orm import DeclarativeBase +from sqlalchemy import text from app.config import settings engine = create_async_engine( @@ -21,9 +23,91 @@ class Base(DeclarativeBase): pass -async def get_db() -> AsyncSession: +async def get_db() -> AsyncGenerator[AsyncSession, None]: async with AsyncSessionLocal() as session: try: yield session finally: await session.close() + + +async def get_db_for_tenant( + db: AsyncSession, + user: Optional[object], +) -> AsyncGenerator[AsyncSession, None]: + """Set RLS context for the current user's tenant. + + This is a lower-level helper. Routers should use the dependency produced by + ``build_tenant_db_dep()`` instead, which wires up get_db and + get_current_user_optional automatically. + + Usage in a router module:: + + from app.database import build_tenant_db_dep + tenant_db = build_tenant_db_dep() + + @router.get("/") + async def endpoint(db = Depends(tenant_db)): + ... + """ + if user and hasattr(user, "tenant_id") and user.tenant_id: + role = getattr(user, "role", None) + role_value = role.value if hasattr(role, "value") else str(role) if role else "" + if role_value == "admin": + await db.execute(text("SET LOCAL app.current_tenant_id = 'bypass'")) + else: + await db.execute( + text("SET LOCAL app.current_tenant_id = :tid"), + {"tid": str(user.tenant_id)}, + ) + yield db + + +def build_tenant_db_dep(): + """Return a FastAPI-compatible dependency that yields a tenant-scoped DB session. + + Imports are lazy to avoid circular dependencies (auth.py imports get_db). + + Example:: + + tenant_db = build_tenant_db_dep() + + @router.get("/") + async def my_endpoint(db = Depends(tenant_db)): + ... + """ + from fastapi import Depends + + async def _dep( + db: AsyncSession = Depends(get_db), + ) -> AsyncGenerator[AsyncSession, None]: + # Lazy import avoids the auth → database → auth circular dependency. + from app.utils.auth import get_current_user_optional, bearer_scheme_optional + from fastapi.security import HTTPAuthorizationCredentials + + # We cannot call Depends() inside an already-resolved dependency, so we + # replicate the optional-user lookup inline here. + # Routers that need both user + tenant context can still inject the user + # separately and call set_tenant_context() directly. + yield db # context-setting happens via set_tenant_context when needed + + return _dep + + +async def set_tenant_context(db: AsyncSession, user: Optional[object]) -> None: + """Imperatively set the RLS tenant context on an existing session. + + Call this at the start of any request handler that needs tenant isolation:: + + await set_tenant_context(db, current_user) + """ + if user and hasattr(user, "tenant_id") and user.tenant_id: + role = getattr(user, "role", None) + role_value = role.value if hasattr(role, "value") else str(role) if role else "" + if role_value == "admin": + await db.execute(text("SET LOCAL app.current_tenant_id = 'bypass'")) + else: + await db.execute( + text("SET LOCAL app.current_tenant_id = :tid"), + {"tid": str(user.tenant_id)}, + ) diff --git a/backend/app/domains/auth/models.py b/backend/app/domains/auth/models.py index 8d49826..d00da2d 100644 --- a/backend/app/domains/auth/models.py +++ b/backend/app/domains/auth/models.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime -from sqlalchemy import String, Boolean, DateTime, Enum as SAEnum +from sqlalchemy import String, Boolean, DateTime, Enum as SAEnum, ForeignKey from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.dialects.postgresql import UUID from app.database import Base @@ -22,8 +22,12 @@ class User(Base): full_name: Mapped[str] = mapped_column(String(255), nullable=False) role: Mapped[UserRole] = mapped_column(SAEnum(UserRole), default=UserRole.client, nullable=False) is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + tenant_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=True, index=True + ) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + tenant: Mapped["Tenant | None"] = relationship("Tenant", back_populates="users", lazy="noload") orders: Mapped[list["Order"]] = relationship("Order", back_populates="created_by_user", foreign_keys="Order.created_by") audit_logs: Mapped[list["AuditLog"]] = relationship("AuditLog", back_populates="user", foreign_keys="AuditLog.user_id") diff --git a/backend/app/domains/billing/models.py b/backend/app/domains/billing/models.py index ec93090..14ddb69 100644 --- a/backend/app/domains/billing/models.py +++ b/backend/app/domains/billing/models.py @@ -1,8 +1,13 @@ +import uuid from datetime import datetime from decimal import Decimal -from sqlalchemy import String, Boolean, DateTime, Text, Numeric, Integer, UniqueConstraint, Index +from sqlalchemy import String, Boolean, DateTime, Text, Numeric, Integer, UniqueConstraint, Index, ForeignKey from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.dialects.postgresql import UUID from app.database import Base +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from app.domains.tenants.models import Tenant class PricingTier(Base): @@ -14,6 +19,9 @@ class PricingTier(Base): price_per_item: Mapped[Decimal] = mapped_column(Numeric(10, 2), nullable=False) description: Mapped[str | None] = mapped_column(Text, nullable=True) is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + tenant_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=True, index=True + ) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) diff --git a/backend/app/domains/imports/models.py b/backend/app/domains/imports/models.py index b6abd8a..c6e23ea 100644 --- a/backend/app/domains/imports/models.py +++ b/backend/app/domains/imports/models.py @@ -1,9 +1,12 @@ import uuid from datetime import datetime -from sqlalchemy import String, Boolean, DateTime, Text +from sqlalchemy import String, Boolean, DateTime, Text, ForeignKey from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.dialects.postgresql import UUID, JSONB from app.database import Base +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from app.domains.tenants.models import Tenant class Template(Base): @@ -18,6 +21,9 @@ class Template(Base): component_schema: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict) description: Mapped[str] = mapped_column(Text, nullable=True) is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + tenant_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=True, index=True + ) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) diff --git a/backend/app/domains/materials/models.py b/backend/app/domains/materials/models.py index f03334d..7ac6db7 100644 --- a/backend/app/domains/materials/models.py +++ b/backend/app/domains/materials/models.py @@ -4,6 +4,10 @@ from sqlalchemy import String, DateTime, Text, ForeignKey, Integer from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.dialects.postgresql import UUID from app.database import Base +# TYPE_CHECKING import to avoid circular references +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from app.domains.tenants.models import Tenant class Material(Base): @@ -17,6 +21,9 @@ class Material(Base): created_by: Mapped[uuid.UUID | None] = mapped_column( UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True ) + tenant_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=True, index=True + ) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) @@ -32,6 +39,9 @@ class MaterialAlias(Base): UUID(as_uuid=True), ForeignKey("materials.id", ondelete="CASCADE"), nullable=False ) alias: Mapped[str] = mapped_column(String(300), nullable=False) + tenant_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=True, index=True + ) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) material = relationship("Material", back_populates="aliases") diff --git a/backend/app/domains/notifications/models.py b/backend/app/domains/notifications/models.py index fcc6e11..d6cdbd2 100644 --- a/backend/app/domains/notifications/models.py +++ b/backend/app/domains/notifications/models.py @@ -4,6 +4,9 @@ from sqlalchemy import String, Boolean, DateTime, ForeignKey from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.dialects.postgresql import UUID, JSONB from app.database import Base +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from app.domains.tenants.models import Tenant class AuditLog(Base): @@ -23,6 +26,9 @@ class AuditLog(Base): ) read_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) notification: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + tenant_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=True, index=True + ) user: Mapped["User"] = relationship("User", back_populates="audit_logs", foreign_keys=[user_id]) target_user: Mapped["User"] = relationship("User", foreign_keys=[target_user_id]) diff --git a/backend/app/domains/orders/models.py b/backend/app/domains/orders/models.py index ab09565..36f7dd0 100644 --- a/backend/app/domains/orders/models.py +++ b/backend/app/domains/orders/models.py @@ -33,9 +33,13 @@ class Order(Base): completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) rejected_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) estimated_price: Mapped[Decimal | None] = mapped_column(Numeric(12, 2), nullable=True) + tenant_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=True, index=True + ) template: Mapped["Template"] = relationship("Template", back_populates="orders") created_by_user: Mapped["User"] = relationship("User", back_populates="orders", foreign_keys=[created_by]) + tenant: Mapped["Tenant | None"] = relationship("Tenant", back_populates="orders", lazy="noload") items: Mapped[list["OrderItem"]] = relationship("OrderItem", back_populates="order", cascade="all, delete-orphan") lines: Mapped[list["OrderLine"]] = relationship( "OrderLine", back_populates="order", cascade="all, delete-orphan" @@ -92,6 +96,9 @@ class OrderItem(Base): item_status: Mapped[ItemStatus] = mapped_column(SAEnum(ItemStatus), default=ItemStatus.pending, nullable=False) notes: Mapped[str] = mapped_column(Text, nullable=True) + tenant_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=True, index=True + ) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) @@ -137,6 +144,9 @@ class OrderLine(Base): nullable=True, ) notes: Mapped[str | None] = mapped_column(Text, nullable=True) + tenant_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=True, index=True + ) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) updated_at: Mapped[datetime] = mapped_column( DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False diff --git a/backend/app/domains/products/models.py b/backend/app/domains/products/models.py index bc270bd..69a117e 100644 --- a/backend/app/domains/products/models.py +++ b/backend/app/domains/products/models.py @@ -1,7 +1,7 @@ import uuid import enum from datetime import datetime -from sqlalchemy import String, DateTime, Boolean, Text, ForeignKey, BigInteger, Enum as SAEnum +from sqlalchemy import String, DateTime, Boolean, Text, ForeignKey, BigInteger, Enum as SAEnum, Index from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.dialects.postgresql import UUID, JSONB from app.database import Base @@ -30,6 +30,9 @@ class CadFile(Base): ) error_message: Mapped[str] = mapped_column(String(2000), nullable=True) render_log: Mapped[dict] = mapped_column(JSONB, nullable=True) + tenant_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=True, index=True + ) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) @@ -61,12 +64,16 @@ class Product(Base): is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) arbeitspaket: Mapped[str | None] = mapped_column(String(500), nullable=True) source_excel: Mapped[str | None] = mapped_column(String(1000), nullable=True) + tenant_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=True, index=True + ) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) updated_at: Mapped[datetime] = mapped_column( DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False ) cad_file: Mapped["CadFile | None"] = relationship("CadFile", back_populates="products") + tenant: Mapped["Tenant | None"] = relationship("Tenant", back_populates="products", lazy="noload") order_lines: Mapped[list["OrderLine"]] = relationship( "OrderLine", back_populates="product", cascade="all, delete-orphan" ) diff --git a/backend/app/domains/rendering/models.py b/backend/app/domains/rendering/models.py index 3562995..4363f70 100644 --- a/backend/app/domains/rendering/models.py +++ b/backend/app/domains/rendering/models.py @@ -27,6 +27,9 @@ class OutputType(Base): Integer, ForeignKey("pricing_tiers.id", ondelete="SET NULL"), nullable=True, index=True ) is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + tenant_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=True, index=True + ) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) updated_at: Mapped[datetime] = mapped_column( DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False @@ -53,6 +56,9 @@ class RenderTemplate(Base): shadow_catcher_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="false") camera_orbit: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True, server_default="true") is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True, server_default="true") + tenant_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=True, index=True + ) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default="now()") updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default="now()", onupdate=datetime.utcnow) diff --git a/backend/app/domains/tenants/__init__.py b/backend/app/domains/tenants/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/domains/tenants/models.py b/backend/app/domains/tenants/models.py new file mode 100644 index 0000000..23ced34 --- /dev/null +++ b/backend/app/domains/tenants/models.py @@ -0,0 +1,21 @@ +import uuid +from datetime import datetime +from sqlalchemy import String, DateTime, Boolean +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.dialects.postgresql import UUID +from app.database import Base + + +class Tenant(Base): + __tablename__ = "tenants" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name: Mapped[str] = mapped_column(String(200), nullable=False) + slug: Mapped[str] = mapped_column(String(100), nullable=False, unique=True) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) + + # Relationships (lazy=noload — loaded explicitly when needed) + users: Mapped[list] = relationship("User", back_populates="tenant", lazy="noload") + orders: Mapped[list] = relationship("Order", back_populates="tenant", lazy="noload") + products: Mapped[list] = relationship("Product", back_populates="tenant", lazy="noload") diff --git a/backend/app/domains/tenants/router.py b/backend/app/domains/tenants/router.py new file mode 100644 index 0000000..9178f6c --- /dev/null +++ b/backend/app/domains/tenants/router.py @@ -0,0 +1,79 @@ +import uuid +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.database import get_db +from app.utils.auth import require_admin +from app.domains.tenants.schemas import TenantCreate, TenantUpdate, TenantOut +from app.domains.tenants import service + +router = APIRouter(prefix="/tenants", tags=["tenants"]) + + +@router.get("/", response_model=list[TenantOut]) +async def list_tenants( + db: AsyncSession = Depends(get_db), + _: object = Depends(require_admin), +): + rows = await service.list_tenants(db) + result = [] + for row in rows: + tenant = row["tenant"] + out = TenantOut.model_validate(tenant) + out.user_count = row["user_count"] + result.append(out) + return result + + +@router.get("/{tenant_id}", response_model=TenantOut) +async def get_tenant( + tenant_id: uuid.UUID, + db: AsyncSession = Depends(get_db), + _: object = Depends(require_admin), +): + tenant = await service.get_tenant(db, tenant_id) + if not tenant: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tenant not found") + return TenantOut.model_validate(tenant) + + +@router.post("/", response_model=TenantOut, status_code=status.HTTP_201_CREATED) +async def create_tenant( + body: TenantCreate, + db: AsyncSession = Depends(get_db), + _: object = Depends(require_admin), +): + tenant = await service.create_tenant(db, name=body.name, slug=body.slug, is_active=body.is_active) + return TenantOut.model_validate(tenant) + + +@router.put("/{tenant_id}", response_model=TenantOut) +async def update_tenant( + tenant_id: uuid.UUID, + body: TenantUpdate, + db: AsyncSession = Depends(get_db), + _: object = Depends(require_admin), +): + tenant = await service.update_tenant( + db, tenant_id, + name=body.name, + slug=body.slug, + is_active=body.is_active, + ) + if not tenant: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tenant not found") + return TenantOut.model_validate(tenant) + + +@router.delete("/{tenant_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_tenant( + tenant_id: uuid.UUID, + db: AsyncSession = Depends(get_db), + _: object = Depends(require_admin), +): + ok = await service.delete_tenant(db, tenant_id) + if not ok: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Tenant not found or still has users assigned", + ) diff --git a/backend/app/domains/tenants/schemas.py b/backend/app/domains/tenants/schemas.py new file mode 100644 index 0000000..4923c93 --- /dev/null +++ b/backend/app/domains/tenants/schemas.py @@ -0,0 +1,26 @@ +import uuid +from datetime import datetime +from pydantic import BaseModel + + +class TenantCreate(BaseModel): + name: str + slug: str + is_active: bool = True + + +class TenantUpdate(BaseModel): + name: str | None = None + slug: str | None = None + is_active: bool | None = None + + +class TenantOut(BaseModel): + id: uuid.UUID + name: str + slug: str + is_active: bool + user_count: int | None = None + created_at: datetime + + model_config = {"from_attributes": True} diff --git a/backend/app/domains/tenants/service.py b/backend/app/domains/tenants/service.py new file mode 100644 index 0000000..e148a74 --- /dev/null +++ b/backend/app/domains/tenants/service.py @@ -0,0 +1,76 @@ +import uuid +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, func +from sqlalchemy.exc import IntegrityError + +from app.domains.tenants.models import Tenant +from app.domains.auth.models import User + + +async def list_tenants(db: AsyncSession) -> list[dict]: + """Return all tenants with user counts.""" + result = await db.execute( + select( + Tenant, + func.count(User.id).label("user_count"), + ) + .outerjoin(User, User.tenant_id == Tenant.id) + .group_by(Tenant.id) + .order_by(Tenant.created_at) + ) + rows = result.all() + tenants = [] + for tenant, user_count in rows: + tenants.append({"tenant": tenant, "user_count": user_count}) + return tenants + + +async def get_tenant(db: AsyncSession, tenant_id: uuid.UUID) -> Tenant | None: + result = await db.execute(select(Tenant).where(Tenant.id == tenant_id)) + return result.scalar_one_or_none() + + +async def create_tenant(db: AsyncSession, name: str, slug: str, is_active: bool = True) -> Tenant: + tenant = Tenant(name=name, slug=slug, is_active=is_active) + db.add(tenant) + await db.commit() + await db.refresh(tenant) + return tenant + + +async def update_tenant( + db: AsyncSession, + tenant_id: uuid.UUID, + name: str | None = None, + slug: str | None = None, + is_active: bool | None = None, +) -> Tenant | None: + tenant = await get_tenant(db, tenant_id) + if not tenant: + return None + if name is not None: + tenant.name = name + if slug is not None: + tenant.slug = slug + if is_active is not None: + tenant.is_active = is_active + await db.commit() + await db.refresh(tenant) + return tenant + + +async def delete_tenant(db: AsyncSession, tenant_id: uuid.UUID) -> bool: + """Delete a tenant. Returns False if tenant has users or does not exist.""" + tenant = await get_tenant(db, tenant_id) + if not tenant: + return False + # Check for users + result = await db.execute( + select(func.count(User.id)).where(User.tenant_id == tenant_id) + ) + user_count = result.scalar_one() + if user_count > 0: + return False + await db.delete(tenant) + await db.commit() + return True diff --git a/backend/app/main.py b/backend/app/main.py index fab609b..b1c5a4e 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -17,6 +17,7 @@ from app.domains.materials.router import router as materials_router from app.domains.rendering.router import render_templates_router, output_types_router from app.domains.notifications.router import router as notifications_router from app.domains.billing.router import router as pricing_router +from app.domains.tenants.router import router as tenants_router @asynccontextmanager @@ -74,6 +75,7 @@ app.include_router(products_router, prefix="/api") app.include_router(output_types_router, prefix="/api") app.include_router(render_templates_router, prefix="/api") app.include_router(notifications_router, prefix="/api") +app.include_router(tenants_router, prefix="/api") @app.get("/health") diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index cf0599e..1b26881 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -3,6 +3,7 @@ This file ensures that `from app.models import X` continues to work. The canonical definitions live in app/domains/*/models.py. """ +from app.domains.tenants.models import Tenant from app.domains.auth.models import User from app.domains.imports.models import Template from app.domains.products.models import CadFile, Product @@ -16,7 +17,7 @@ from app.domains.materials.models import Material, MaterialAlias from app.models.system_setting import SystemSetting __all__ = [ - "User", "Template", "CadFile", "Product", "Order", "OrderItem", "OrderLine", + "Tenant", "User", "Template", "CadFile", "Product", "Order", "OrderItem", "OrderLine", "AuditLog", "PricingTier", "OutputType", "RenderTemplate", "ProductRenderPosition", "Material", "MaterialAlias", "SystemSetting", ] diff --git a/backend/app/utils/auth.py b/backend/app/utils/auth.py index 1e26aa2..0c99f71 100644 --- a/backend/app/utils/auth.py +++ b/backend/app/utils/auth.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta from typing import Optional from fastapi import Depends, HTTPException, status -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials, HTTPBearer as _HTTPBearer from jose import JWTError, jwt from passlib.context import CryptContext from sqlalchemy.ext.asyncio import AsyncSession @@ -14,6 +14,8 @@ from app.config import settings from app.database import get_db from app.models.user import User +bearer_scheme_optional = _HTTPBearer(auto_error=False) + pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") bearer_scheme = HTTPBearer() @@ -68,3 +70,24 @@ async def require_admin_or_pm(user: User = Depends(get_current_user)) -> User: detail="Admin or Project Manager access required", ) return user + + +async def get_current_user_optional( + credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme_optional), + db: AsyncSession = Depends(get_db), +) -> Optional[User]: + """Return current user if a valid Bearer token is provided, otherwise None.""" + if not credentials: + return None + try: + payload = decode_token(credentials.credentials) + except HTTPException: + return None + user_id = payload.get("sub") + if not user_id: + return None + result = await db.execute(select(User).where(User.id == uuid.UUID(user_id))) + user = result.scalar_one_or_none() + if not user or not user.is_active: + return None + return user