11from collections .abc import AsyncGenerator
2+ from types import SimpleNamespace
23from typing import Any
34
45import pytest
56from httpx import ASGITransport , AsyncClient
7+ from sqlalchemy import text
8+ from sqlalchemy .exc import ProgrammingError
69
7- from app .database import engine
10+ from app .database import engine , test_engine , get_test_db , get_db
811from app .main import app
912from app .models .base import Base
1013from app .redis import get_redis
1922def anyio_backend (request ):
2023 return request .param
2124
25+ def _create_db (conn ) -> None :
26+ """Create a database schema if it doesn't exist."""
27+ try :
28+ conn .execute (text ("CREATE DATABASE testdb" ))
29+ except ProgrammingError :
30+ # This might be raised by databases that don't support `IF NOT EXISTS`
31+ # and the schema already exists. You can choose to ignore it.
32+ pass
33+
34+
35+ def _create_db_schema (conn ) -> None :
36+ """Create a database schema if it doesn't exist."""
37+ try :
38+ conn .execute (text ("CREATE SCHEMA happy_hog" ))
39+ conn .execute (text ("CREATE SCHEMA shakespeare" ))
40+ except ProgrammingError :
41+ # This might be raised by databases that don't support `IF NOT EXISTS`
42+ # and the schema already exists. You can choose to ignore it.
43+ pass
44+
2245
2346@pytest .fixture (scope = "session" )
2447async def start_db ():
25- async with engine .begin () as conn :
48+ # The `engine` is configured for the default 'postgres' database.
49+ # We connect to it and create the test database.
50+ # A transaction block is not used, as CREATE DATABASE cannot run inside it.
51+ async with engine .connect () as conn :
52+ await conn .execute (text ("COMMIT" )) # Ensure we're not in a transaction
53+ await conn .run_sync (_create_db )
54+
55+ # Now, connect to the newly created `testdb` with `test_engine`
56+ async with test_engine .begin () as conn :
57+ await conn .execute (text ("COMMIT" )) # Ensure we're not in a transaction
58+ await conn .run_sync (_create_db_schema )
2659 await conn .run_sync (Base .metadata .drop_all )
2760 await conn .run_sync (Base .metadata .create_all )
2861 # for AsyncEngine created in function scope, close and
2962 # clean-up pooled connections
3063 await engine .dispose ()
64+ await test_engine .dispose ()
3165
3266
3367@pytest .fixture (scope = "session" )
@@ -40,5 +74,6 @@ async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001
4074 headers = {"Content-Type" : "application/json" },
4175 transport = transport ,
4276 ) as test_client :
77+ app .dependency_overrides [get_db ] = get_test_db
4378 app .redis = await get_redis ()
4479 yield test_client
0 commit comments