|
1 | 1 | from collections.abc import AsyncGenerator |
| 2 | +from types import SimpleNamespace |
2 | 3 | from typing import Any |
3 | 4 |
|
4 | 5 | import pytest |
5 | 6 | from httpx import ASGITransport, AsyncClient |
6 | 7 | from sqlalchemy import text |
7 | 8 | from sqlalchemy.exc import ProgrammingError |
8 | | -from sqlalchemy.ext.asyncio import AsyncSession |
9 | 9 |
|
10 | | -from app.database import engine, get_db, test_engine |
| 10 | +from app.database import engine, test_engine, get_test_db, get_db |
11 | 11 | from app.main import app |
12 | 12 | from app.models.base import Base |
13 | 13 | from app.redis import get_redis |
|
22 | 22 | def anyio_backend(request): |
23 | 23 | return request.param |
24 | 24 |
|
25 | | - |
26 | 25 | def _create_db(conn) -> None: |
27 | 26 | """Create the test database if it doesn't exist.""" |
28 | 27 | try: |
@@ -65,43 +64,16 @@ async def start_db(): |
65 | 64 | await test_engine.dispose() |
66 | 65 |
|
67 | 66 |
|
68 | | -@pytest.fixture(scope="function") |
69 | | -async def db_session() -> AsyncGenerator[AsyncSession, Any]: |
70 | | - """ |
71 | | - Provide a transactional database session for each test function. |
72 | | - Rolls back changes after the test. |
73 | | - """ |
74 | | - connection = await test_engine.connect() |
75 | | - transaction = await connection.begin() |
76 | | - session = AsyncSession(bind=connection) |
77 | | - |
78 | | - yield session |
79 | | - |
80 | | - await session.close() |
81 | | - await transaction.rollback() |
82 | | - await connection.close() |
83 | | - |
84 | | - |
85 | | -@pytest.fixture(scope="function") |
86 | | -async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, Any]: |
87 | | - """ |
88 | | - Provide a test client for making API requests. |
89 | | - Uses the function-scoped db_session for test isolation. |
90 | | - """ |
91 | | - |
92 | | - def get_test_db_override(): |
93 | | - yield db_session |
94 | | - |
95 | | - app.dependency_overrides[get_db] = get_test_db_override |
96 | | - app.redis = await get_redis() |
97 | | - |
98 | | - transport = ASGITransport(app=app) |
| 67 | +@pytest.fixture(scope="session") |
| 68 | +async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001 |
| 69 | + transport = ASGITransport( |
| 70 | + app=app, |
| 71 | + ) |
99 | 72 | async with AsyncClient( |
100 | 73 | base_url="http://testserver/v1", |
101 | 74 | headers={"Content-Type": "application/json"}, |
102 | 75 | transport=transport, |
103 | 76 | ) as test_client: |
| 77 | + app.dependency_overrides[get_db] = get_test_db |
| 78 | + app.redis = await get_redis() |
104 | 79 | yield test_client |
105 | | - |
106 | | - # Clean up dependency overrides |
107 | | - del app.dependency_overrides[get_db] |
0 commit comments