from __future__ import annotations from contextlib import contextmanager import importlib import os from typing import Iterator from sqlalchemy import text from sqlalchemy.engine import make_url from sqlalchemy.orm import Session from sqlalchemy import create_engine from app.database import Base def resolve_test_db_url(*, async_driver: bool) -> str: explicit_url = os.environ.get("TEST_DATABASE_URL") if explicit_url: db_url = explicit_url else: host = os.environ.get("TEST_POSTGRES_HOST") or os.environ.get("POSTGRES_HOST") or "localhost" port = os.environ.get("TEST_POSTGRES_PORT") or os.environ.get("POSTGRES_PORT") or "5432" user = os.environ.get("TEST_POSTGRES_USER") or os.environ.get("POSTGRES_USER") or "hartomat" password = os.environ.get("TEST_POSTGRES_PASSWORD") or os.environ.get("POSTGRES_PASSWORD") or "hartomat" default_db = f"{os.environ.get('POSTGRES_DB', 'hartomat')}_test" database = os.environ.get("TEST_POSTGRES_DB") or os.environ.get("TEST_DB_NAME") or default_db driver = "postgresql+asyncpg" if async_driver else "postgresql" db_url = f"{driver}://{user}:{password}@{host}:{port}/{database}" normalized_url = db_url if async_driver else db_url.replace("+asyncpg", "") database_name = make_url(normalized_url).database or "" if not database_name.endswith("_test"): raise RuntimeError( f"Refusing to run destructive test database setup against non-test database '{database_name}'." ) return normalized_url def reset_public_schema_sync(connection) -> None: connection.execute(text("DROP SCHEMA IF EXISTS public CASCADE")) connection.execute(text("CREATE SCHEMA public")) async def reset_public_schema_async(connection) -> None: await connection.execute(text("DROP SCHEMA IF EXISTS public CASCADE")) await connection.execute(text("CREATE SCHEMA public")) def import_all_model_modules() -> None: module_names = ( "app.domains.tenants.models", "app.domains.auth.models", "app.domains.imports.models", "app.domains.products.models", "app.domains.orders.models", "app.domains.notifications.models", "app.domains.billing.models", "app.domains.rendering.models", "app.domains.materials.models", "app.domains.media.models", "app.domains.admin.models", "app.models.system_setting", "app.models.worker_config", "app.models.chat", ) for module_name in module_names: importlib.import_module(module_name) @contextmanager def sync_test_session() -> Iterator[Session]: import_all_model_modules() engine = create_engine(resolve_test_db_url(async_driver=False)) with engine.begin() as conn: reset_public_schema_sync(conn) Base.metadata.create_all(conn) session = Session(engine) try: yield session finally: session.close() with engine.begin() as conn: reset_public_schema_sync(conn) engine.dispose()