66from sqlalchemy import text
77from sqlalchemy .exc import ProgrammingError
88
9- from app .database import engine , get_db , get_test_db , test_engine
9+ from app .database import TestAsyncSessionFactory , engine , get_db , test_engine
1010from app .main import app
1111from app .models .base import Base
1212from app .redis import get_redis
@@ -43,7 +43,7 @@ def _create_db_schema(conn) -> None:
4343 pass
4444
4545
46- @pytest .fixture (scope = "session" )
46+ @pytest .fixture (scope = "session" , autouse = True )
4747async def start_db ():
4848 # The `engine` is configured for the default 'postgres' database.
4949 # We connect to it and create the test database.
@@ -63,16 +63,36 @@ async def start_db():
6363 await test_engine .dispose ()
6464
6565
66- @pytest .fixture (scope = "session" )
67- async def client (start_db ) -> AsyncGenerator [AsyncClient , Any ]: # noqa: ARG001
66+ @pytest .fixture ()
67+ async def db_session ():
68+ connection = await test_engine .connect ()
69+ transaction = await connection .begin ()
70+ session = TestAsyncSessionFactory (bind = connection )
71+
72+ try :
73+ yield session
74+ finally :
75+ # Rollback the overall transaction, restoring the state before the test ran.
76+ await session .close ()
77+ if transaction .is_active :
78+ await transaction .rollback ()
79+ await connection .close ()
80+
81+
82+ @pytest .fixture (scope = "function" )
83+ async def client (db_session ) -> AsyncGenerator [AsyncClient , Any ]:
6884 transport = ASGITransport (
6985 app = app ,
7086 )
87+
88+ async def override_get_db ():
89+ yield db_session
90+
7191 async with AsyncClient (
7292 base_url = "http://testserver/v1" ,
7393 headers = {"Content-Type" : "application/json" },
7494 transport = transport ,
7595 ) as test_client :
76- app .dependency_overrides [get_db ] = get_test_db
96+ app .dependency_overrides [get_db ] = override_get_db
7797 app .redis = await get_redis ()
7898 yield test_client
0 commit comments