chore: snapshot workflow migration progress
This commit is contained in:
@@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
import importlib
|
||||
import os
|
||||
from typing import Iterator
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
def resolve_test_db_url(*, async_driver: bool) -> str:
|
||||
explicit_url = os.environ.get("TEST_DATABASE_URL")
|
||||
if explicit_url:
|
||||
db_url = explicit_url
|
||||
else:
|
||||
host = os.environ.get("TEST_POSTGRES_HOST") or os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
port = os.environ.get("TEST_POSTGRES_PORT") or os.environ.get("POSTGRES_PORT") or "5432"
|
||||
user = os.environ.get("TEST_POSTGRES_USER") or os.environ.get("POSTGRES_USER") or "hartomat"
|
||||
password = os.environ.get("TEST_POSTGRES_PASSWORD") or os.environ.get("POSTGRES_PASSWORD") or "hartomat"
|
||||
default_db = f"{os.environ.get('POSTGRES_DB', 'hartomat')}_test"
|
||||
database = os.environ.get("TEST_POSTGRES_DB") or os.environ.get("TEST_DB_NAME") or default_db
|
||||
driver = "postgresql+asyncpg" if async_driver else "postgresql"
|
||||
db_url = f"{driver}://{user}:{password}@{host}:{port}/{database}"
|
||||
|
||||
normalized_url = db_url if async_driver else db_url.replace("+asyncpg", "")
|
||||
database_name = make_url(normalized_url).database or ""
|
||||
if not database_name.endswith("_test"):
|
||||
raise RuntimeError(
|
||||
f"Refusing to run destructive test database setup against non-test database '{database_name}'."
|
||||
)
|
||||
return normalized_url
|
||||
|
||||
|
||||
def reset_public_schema_sync(connection) -> None:
|
||||
connection.execute(text("DROP SCHEMA IF EXISTS public CASCADE"))
|
||||
connection.execute(text("CREATE SCHEMA public"))
|
||||
|
||||
|
||||
async def reset_public_schema_async(connection) -> None:
|
||||
await connection.execute(text("DROP SCHEMA IF EXISTS public CASCADE"))
|
||||
await connection.execute(text("CREATE SCHEMA public"))
|
||||
|
||||
|
||||
def import_all_model_modules() -> None:
|
||||
module_names = (
|
||||
"app.domains.tenants.models",
|
||||
"app.domains.auth.models",
|
||||
"app.domains.imports.models",
|
||||
"app.domains.products.models",
|
||||
"app.domains.orders.models",
|
||||
"app.domains.notifications.models",
|
||||
"app.domains.billing.models",
|
||||
"app.domains.rendering.models",
|
||||
"app.domains.materials.models",
|
||||
"app.domains.media.models",
|
||||
"app.domains.admin.models",
|
||||
"app.models.system_setting",
|
||||
"app.models.worker_config",
|
||||
"app.models.chat",
|
||||
)
|
||||
for module_name in module_names:
|
||||
importlib.import_module(module_name)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def sync_test_session() -> Iterator[Session]:
|
||||
import_all_model_modules()
|
||||
engine = create_engine(resolve_test_db_url(async_driver=False))
|
||||
with engine.begin() as conn:
|
||||
reset_public_schema_sync(conn)
|
||||
Base.metadata.create_all(conn)
|
||||
|
||||
session = Session(engine)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
with engine.begin() as conn:
|
||||
reset_public_schema_sync(conn)
|
||||
engine.dispose()
|
||||
Reference in New Issue
Block a user