Files
HartOMat/backend/tests/db_test_utils.py
T

86 lines
3.0 KiB
Python

from __future__ import annotations

from contextlib import contextmanager
import importlib
import os
from typing import Iterator

from sqlalchemy import text
from sqlalchemy import create_engine
from sqlalchemy.engine import URL
from sqlalchemy.engine import make_url
from sqlalchemy.orm import Session

from app.database import Base
def _first_env(*names: str, default: str) -> str:
    """Return the first non-empty environment variable among *names*, else *default*."""
    for name in names:
        value = os.environ.get(name)
        if value:
            return value
    return default


def resolve_test_db_url(*, async_driver: bool) -> str:
    """Resolve the database URL used by the destructive test-database helpers.

    ``TEST_DATABASE_URL`` wins when set. Otherwise the URL is assembled from
    ``TEST_POSTGRES_*`` variables, falling back to ``POSTGRES_*`` and then to
    built-in defaults. Empty variables are treated as unset.

    Args:
        async_driver: When True, a component-built URL uses the
            ``postgresql+asyncpg`` driver. An explicit ``TEST_DATABASE_URL``
            keeps whatever driver it names. When False, any ``+asyncpg``
            suffix is removed from the driver name so the URL targets the
            default synchronous PostgreSQL driver.

    Returns:
        The URL as a string, with the password in clear text so it can be
        passed straight to ``create_engine``.

    Raises:
        RuntimeError: If the target database name does not end in ``_test``.
            The helpers in this module drop the entire ``public`` schema, so
            this guards against pointing them at a non-test database.
    """
    explicit_url = os.environ.get("TEST_DATABASE_URL")
    if explicit_url:
        url = make_url(explicit_url)
    else:
        base_db = os.environ.get("POSTGRES_DB") or "hartomat"
        # URL.create escapes credentials itself. Interpolating them into an
        # f-string breaks as soon as a password contains '@', '/', ':' or '%'.
        url = URL.create(
            drivername="postgresql+asyncpg" if async_driver else "postgresql",
            username=_first_env("TEST_POSTGRES_USER", "POSTGRES_USER", default="hartomat"),
            password=_first_env("TEST_POSTGRES_PASSWORD", "POSTGRES_PASSWORD", default="hartomat"),
            host=_first_env("TEST_POSTGRES_HOST", "POSTGRES_HOST", default="localhost"),
            port=int(_first_env("TEST_POSTGRES_PORT", "POSTGRES_PORT", default="5432")),
            database=_first_env("TEST_POSTGRES_DB", "TEST_DB_NAME", default=f"{base_db}_test"),
        )
    if not async_driver:
        # Rewrite only the driver name. A whole-string replace could also
        # alter matching text inside the password or database name.
        url = url.set(drivername=url.drivername.replace("+asyncpg", ""))
    database_name = url.database or ""
    if not database_name.endswith("_test"):
        raise RuntimeError(
            f"Refusing to run destructive test database setup against non-test database '{database_name}'."
        )
    # str(URL) masks the password as '***' on SQLAlchemy 2.x, which would make
    # the returned string unusable for connecting.
    return url.render_as_string(hide_password=False)
def reset_public_schema_sync(connection) -> None:
    """Drop and recreate the ``public`` schema through a synchronous connection.

    ``CASCADE`` removes everything contained in the schema, so only call this
    against a disposable test database. No commit is issued here; the caller
    owns the surrounding transaction.
    """
    for statement in ("DROP SCHEMA IF EXISTS public CASCADE", "CREATE SCHEMA public"):
        connection.execute(text(statement))
async def reset_public_schema_async(connection) -> None:
    """Drop and recreate the ``public`` schema through an awaitable connection.

    Each statement is awaited in turn: first the cascading drop, then the
    re-creation. No commit is issued here; the caller owns the surrounding
    transaction.
    """
    for statement in ("DROP SCHEMA IF EXISTS public CASCADE", "CREATE SCHEMA public"):
        await connection.execute(text(statement))
# Dotted paths of the modules imported by import_all_model_modules, in import order.
_MODEL_MODULE_NAMES = (
    "app.domains.tenants.models",
    "app.domains.auth.models",
    "app.domains.imports.models",
    "app.domains.products.models",
    "app.domains.orders.models",
    "app.domains.notifications.models",
    "app.domains.billing.models",
    "app.domains.rendering.models",
    "app.domains.materials.models",
    "app.domains.media.models",
    "app.domains.admin.models",
    "app.models.system_setting",
    "app.models.worker_config",
    "app.models.chat",
)


def import_all_model_modules() -> None:
    """Import each module listed in ``_MODEL_MODULE_NAMES``, in order.

    NOTE(review): presumably done so every model class is registered on
    ``Base.metadata`` before ``create_all`` runs; confirm that new model
    modules get added to the list above.
    """
    for dotted_path in _MODEL_MODULE_NAMES:
        importlib.import_module(dotted_path)
@contextmanager
def sync_test_session() -> Iterator[Session]:
    """Yield a synchronous ORM session backed by a freshly built test schema.

    Setup resets the ``public`` schema and creates all model tables in one
    transaction. On exit the session is closed, the schema is reset again,
    and the engine is disposed. Disposal also happens when setup or teardown
    raises, so pooled connections are never leaked.

    Raises:
        RuntimeError: From ``resolve_test_db_url`` when the configured
            database name does not end in ``_test``.
    """
    import_all_model_modules()
    engine = create_engine(resolve_test_db_url(async_driver=False))
    try:
        with engine.begin() as conn:
            reset_public_schema_sync(conn)
            Base.metadata.create_all(conn)
        try:
            # Session's context manager closes the session on exit, including
            # when the test body raises through the yield.
            with Session(engine) as session:
                yield session
        finally:
            # Runs after the session has released its connection, so the drop
            # is not blocked by locks the session still held.
            with engine.begin() as conn:
                reset_public_schema_sync(conn)
    finally:
        engine.dispose()