86 lines
3.0 KiB
Python
86 lines
3.0 KiB
Python
from __future__ import annotations
|
|
|
|
from contextlib import contextmanager
|
|
import importlib
|
|
import os
|
|
from typing import Iterator
|
|
|
|
from sqlalchemy import text
|
|
from sqlalchemy.engine import make_url
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import create_engine
|
|
|
|
from app.database import Base
|
|
|
|
|
|
def resolve_test_db_url(*, async_driver: bool) -> str:
    """Return the database URL for the test suite, refusing non-test databases.

    If ``TEST_DATABASE_URL`` is set, it is used as given.

    Otherwise the URL is assembled from environment variables. Each of host,
    port, user and password is read from ``TEST_POSTGRES_*``, then the matching
    ``POSTGRES_*`` variable, then a local default. The database name comes from
    ``TEST_POSTGRES_DB`` or ``TEST_DB_NAME``, falling back to
    ``"<POSTGRES_DB or 'hartomat'>_test"``.

    Args:
        async_driver: When true, an assembled URL uses the ``postgresql+asyncpg``
            driver. An explicit ``TEST_DATABASE_URL`` is returned unchanged, so
            it must already name an async driver. When false, a ``+asyncpg``
            suffix on the URL scheme is stripped so the URL targets the plain
            ``postgresql`` driver.

    Returns:
        The resolved URL as a string.

    Raises:
        RuntimeError: If the resolved database name does not end in ``_test``.
            This guards against running destructive test setup against a real
            database.
    """
    from urllib.parse import quote

    explicit_url = os.environ.get("TEST_DATABASE_URL")
    if explicit_url:
        db_url = explicit_url
    else:
        host = os.environ.get("TEST_POSTGRES_HOST") or os.environ.get("POSTGRES_HOST") or "localhost"
        port = os.environ.get("TEST_POSTGRES_PORT") or os.environ.get("POSTGRES_PORT") or "5432"
        user = os.environ.get("TEST_POSTGRES_USER") or os.environ.get("POSTGRES_USER") or "hartomat"
        password = (
            os.environ.get("TEST_POSTGRES_PASSWORD") or os.environ.get("POSTGRES_PASSWORD") or "hartomat"
        )
        default_db = f"{os.environ.get('POSTGRES_DB', 'hartomat')}_test"
        database = os.environ.get("TEST_POSTGRES_DB") or os.environ.get("TEST_DB_NAME") or default_db
        driver = "postgresql+asyncpg" if async_driver else "postgresql"
        # Percent-encode the credentials. Characters such as '@', ':', '/', '#'
        # or '%' in a user name or password would otherwise break URL parsing.
        # SQLAlchemy's URL parser decodes the escapes again. For plain
        # alphanumeric credentials this is a no-op.
        db_url = (
            f"{driver}://{quote(user, safe='')}:{quote(password, safe='')}"
            f"@{host}:{port}/{database}"
        )

    if async_driver:
        normalized_url = db_url
    else:
        # Rewrite only the scheme. Replacing across the whole URL could also
        # alter a password or database name that happens to contain "+asyncpg".
        scheme, separator, remainder = db_url.partition("://")
        normalized_url = scheme.replace("+asyncpg", "") + separator + remainder

    database_name = make_url(normalized_url).database or ""
    if not database_name.endswith("_test"):
        raise RuntimeError(
            f"Refusing to run destructive test database setup against non-test database '{database_name}'."
        )

    return normalized_url
|
|
|
|
|
|
def reset_public_schema_sync(connection) -> None:
    """Drop the ``public`` schema (CASCADE) and recreate it empty.

    Both statements run through ``connection.execute`` in order: drop, then
    create. Committing or rolling back is left to the caller.
    """
    for statement in ("DROP SCHEMA IF EXISTS public CASCADE", "CREATE SCHEMA public"):
        connection.execute(text(statement))
|
|
|
|
|
|
async def reset_public_schema_async(connection) -> None:
    """Async variant: drop the ``public`` schema (CASCADE) and recreate it empty.

    ``connection.execute`` must return an awaitable. The statements are awaited
    one after the other: drop, then create. Committing or rolling back is left
    to the caller.
    """
    for statement in ("DROP SCHEMA IF EXISTS public CASCADE", "CREATE SCHEMA public"):
        await connection.execute(text(statement))
|
|
|
|
|
|
def import_all_model_modules() -> None:
    """Import every listed model module for its import-time side effects.

    ``sync_test_session`` calls this before ``Base.metadata.create_all``,
    presumably so every model's table is registered on the metadata first
    (confirm against the model modules). Modules are imported in the listed
    order. Any import error propagates to the caller.
    """
    model_modules = (
        # Per-domain model packages.
        "app.domains.tenants.models",
        "app.domains.auth.models",
        "app.domains.imports.models",
        "app.domains.products.models",
        "app.domains.orders.models",
        "app.domains.notifications.models",
        "app.domains.billing.models",
        "app.domains.rendering.models",
        "app.domains.materials.models",
        "app.domains.media.models",
        "app.domains.admin.models",
        # Stand-alone model modules.
        "app.models.system_setting",
        "app.models.worker_config",
        "app.models.chat",
    )
    for dotted_path in model_modules:
        importlib.import_module(dotted_path)
|
|
|
|
|
|
@contextmanager
def sync_test_session() -> Iterator[Session]:
    """Yield a ``Session`` bound to a freshly rebuilt test schema.

    Setup:
        Imports the model modules and creates an engine for the synchronous
        test URL from ``resolve_test_db_url``, which refuses databases whose
        name does not end in ``_test``. Then, in one transaction, it drops and
        recreates the ``public`` schema and creates all tables in
        ``Base.metadata``.

    Teardown:
        Closes the session, resets the ``public`` schema to empty, and
        disposes of the engine.

    Cleanup runs even if the test body or setup raises. The engine is always
    disposed. The schema reset runs whenever a session was handed out, even if
    closing the session fails.
    """
    import_all_model_modules()
    engine = create_engine(resolve_test_db_url(async_driver=False))
    try:
        # If this block raises, engine.begin() rolls the transaction back, so
        # there is nothing to reset; the outer finally still disposes the engine.
        with engine.begin() as conn:
            reset_public_schema_sync(conn)
            Base.metadata.create_all(conn)

        try:
            # The Session context manager closes the session on exit.
            with Session(engine) as session:
                yield session
        finally:
            # Drop everything the test created, even if the test failed.
            with engine.begin() as conn:
                reset_public_schema_sync(conn)
    finally:
        engine.dispose()
|