33
44import pytest
55from httpx import ASGITransport , AsyncClient
6+ from sqlalchemy import text
7+ from sqlalchemy .exc import ProgrammingError
68
7- from app .database import engine
9+ from app .database import engine , get_db , get_test_db , test_engine
810from app .main import app
911from app .models .base import Base
1012from app .redis import get_redis
1921def anyio_backend (request ):
2022 return request .param
2123
24+ def _create_db (conn ) -> None :
25+ """Create the test database if it doesn't exist."""
26+ try :
27+ conn .execute (text ("CREATE DATABASE testdb" ))
28+ except ProgrammingError :
29+ # This might be raised by databases that don't support `IF NOT EXISTS`
30+ # and the schema already exists. You can choose to ignore it.
31+ pass
32+
33+
34+ def _create_db_schema (conn ) -> None :
35+ """Create a database schema if it doesn't exist."""
36+ try :
37+ """Create a database schema if it doesn't exist."""
38+ conn .execute (text ("CREATE SCHEMA IF NOT EXISTS happy_hog" ))
39+ conn .execute (text ("CREATE SCHEMA IF NOT EXISTS shakespeare" ))
40+ except ProgrammingError :
41+ # This might be raised by databases that don't support `IF NOT EXISTS`
42+ # and the schema already exists. You can choose to ignore it.
43+ pass
44+
2245
2346@pytest .fixture (scope = "session" )
2447async def start_db ():
25- async with engine .begin () as conn :
48+ # The `engine` is configured for the default 'postgres' database.
49+ # We connect to it and create the test database.
50+ # A transaction block is not used, as CREATE DATABASE cannot run inside it.
51+ async with engine .connect () as conn :
52+ await conn .execute (text ("COMMIT" )) # Ensure we're not in a transaction
53+ await conn .run_sync (_create_db )
54+
55+ # Now, connect to the newly created `testdb` with `test_engine`
56+ async with test_engine .begin () as conn :
57+ await conn .run_sync (_create_db_schema )
2658 await conn .run_sync (Base .metadata .drop_all )
2759 await conn .run_sync (Base .metadata .create_all )
2860 # for AsyncEngine created in function scope, close and
2961 # clean-up pooled connections
3062 await engine .dispose ()
63+ await test_engine .dispose ()
3164
3265
3366@pytest .fixture (scope = "session" )
@@ -40,5 +73,6 @@ async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001
4073 headers = {"Content-Type" : "application/json" },
4174 transport = transport ,
4275 ) as test_client :
76+ app .dependency_overrides [get_db ] = get_test_db
4377 app .redis = await get_redis ()
4478 yield test_client
0 commit comments