|
1 | 1 | import bcrypt |
2 | 2 | import pytest |
3 | 3 | from fastapi.testclient import TestClient |
4 | | -from sqlalchemy.orm import Session |
| 4 | +from sqlalchemy import create_engine |
| 5 | +from sqlalchemy.orm import Session, sessionmaker |
5 | 6 |
|
6 | | -from app.core.database import Base, get_db, init_db |
| 7 | +from app.core.database import Base, get_db |
7 | 8 | from app.factory import create_app |
8 | 9 | from app.modules.auth.crud import create_user |
9 | 10 | from app.modules.auth.token import create_access_token |
10 | 11 | from app.settings import Settings |
11 | 12 |
|
| 13 | +# No need for manual load_dotenv, Settings() will handle it via resolve_env_file() |
| 14 | +# or we can pass it explicitly for maximum clarity in tests. |
| 15 | + |
12 | 16 |
|
13 | 17 | @pytest.fixture(scope="session") |
14 | 18 | def test_settings(): |
15 | | - """Session-scoped fixture for test settings, using a file-based SQLite DB.""" |
16 | | - settings = Settings(_env_file=".env.test") |
17 | | - settings.database_url = "sqlite:///./test.db" |
18 | | - return settings |
| 19 | + """Session-scoped fixture for test settings, using a Postgres DB.""" |
| 20 | + # We explicitly pass the env file to ensure we use exactly what we want |
| 21 | + return Settings(_env_file=".env.test") |
19 | 22 |
|
20 | 23 |
|
21 | 24 | @pytest.fixture(scope="session") |
22 | | -def db_engine_session(test_settings: Settings): |
23 | | - """Session-scoped fixture for the database engine and session factory.""" |
24 | | - engine, session_local = init_db(test_settings) |
25 | | - return engine, session_local |
| 25 | +def db_engine(test_settings: Settings): |
| 26 | + """Session-scoped engine for the Postgres database.""" |
| 27 | + engine = create_engine(test_settings.database_url) |
| 28 | + Base.metadata.create_all(bind=engine) |
| 29 | + yield engine |
| 30 | + Base.metadata.drop_all(bind=engine) |
26 | 31 |
|
27 | 32 |
|
28 | 33 | @pytest.fixture(scope="session") |
29 | | -def app(test_settings: Settings, db_engine_session): |
| 34 | +def app(test_settings: Settings, db_engine): |
30 | 35 | """Session-scoped fixture for the FastAPI application instance.""" |
31 | | - _, session_local = db_engine_session |
32 | | - return create_app(test_settings, session_local) |
| 36 | + # We use a dummy session_local here because we'll override get_db at the request level |
| 37 | + dummy_session_local = sessionmaker(bind=db_engine) |
| 38 | + return create_app(test_settings, dummy_session_local) |
33 | 39 |
|
34 | 40 |
|
35 | | -@pytest.fixture(scope="session") |
36 | | -def override_get_db(db_engine_session): |
37 | | - """Session-scoped fixture to override the `get_db` dependency.""" |
38 | | - _, session_local = db_engine_session |
| 41 | +@pytest.fixture |
| 42 | +def db(db_engine): |
| 43 | + """ |
| 44 | + Function-scoped fixture that provides a transactional database session. |
| 45 | + Everything is rolled back at the end of the test. |
| 46 | + """ |
| 47 | + connection = db_engine.connect() |
| 48 | + transaction = connection.begin() |
| 49 | + session = Session(bind=connection) |
39 | 50 |
|
40 | | - def _override_get_db(): |
41 | | - db: Session = session_local() |
42 | | - try: |
43 | | - yield db |
44 | | - finally: |
45 | | - db.close() |
| 51 | + yield session |
46 | 52 |
|
47 | | - return _override_get_db |
| 53 | + session.close() |
| 54 | + transaction.rollback() |
| 55 | + connection.close() |
48 | 56 |
|
49 | 57 |
|
50 | | -@pytest.fixture(scope="session", autouse=True) |
51 | | -def setup_database(app, db_engine_session, override_get_db): |
52 | | - """ |
53 | | - Session-scoped, autouse fixture to set up the database schema and dependency overrides. |
54 | | - """ |
55 | | - engine, _ = db_engine_session |
56 | | - Base.metadata.create_all(bind=engine) |
57 | | - app.dependency_overrides[get_db] = override_get_db |
58 | | - yield |
59 | | - Base.metadata.drop_all(bind=engine) |
| 58 | +@pytest.fixture |
| 59 | +def override_get_db(app, db): |
| 60 | + """Fixture to override the get_db dependency for every test.""" |
60 | 61 |
|
| 62 | + def _get_db(): |
| 63 | + yield db |
61 | 64 |
|
62 | | -@pytest.fixture(scope="session") |
63 | | -def test_user(override_get_db): |
64 | | - """Create a test user in the DB""" |
65 | | - db = next(override_get_db()) |
| 65 | + app.dependency_overrides[get_db] = _get_db |
| 66 | + yield _get_db # Yield the function so it can be called if needed |
| 67 | + app.dependency_overrides.pop(get_db, None) |
| 68 | + |
| 69 | + |
| 70 | +@pytest.fixture(autouse=True) |
| 71 | +def auto_override_get_db(override_get_db): |
| 72 | + """Automatically apply the get_db override for every test.""" |
| 73 | + pass |
| 74 | + |
| 75 | + |
| 76 | +@pytest.fixture |
| 77 | +def test_user(db): |
| 78 | + """Create a test user in the DB for the current transaction.""" |
66 | 79 | user = create_user( |
67 | 80 | db, |
68 | 81 | email="test@example.com", |
69 | 82 | hashed_pw=bcrypt.hashpw(b"password123", bcrypt.gensalt()).decode("utf-8"), |
70 | 83 | ) |
71 | | - |
72 | 84 | return user |
73 | 85 |
|
74 | 86 |
|
75 | | -@pytest.fixture(scope="session") |
| 87 | +@pytest.fixture |
76 | 88 | def access_token(test_user): |
77 | 89 | """Create an access token for the test user.""" |
78 | 90 | return create_access_token({"sub": str(test_user.email)}) |
79 | 91 |
|
80 | 92 |
|
81 | | -@pytest.fixture(scope="module") |
| 93 | +@pytest.fixture |
82 | 94 | def client(app): |
83 | | - """Module-scoped test client for unauthenticated requests.""" |
| 95 | + """Test client for unauthenticated requests.""" |
84 | 96 | with TestClient(app) as c: |
85 | 97 | yield c |
86 | 98 |
|
87 | 99 |
|
88 | 100 | @pytest.fixture |
89 | 101 | def auth_client(app, access_token): |
90 | | - """Function-scoped test client for authenticated requests.""" |
| 102 | + """Test client for authenticated requests.""" |
91 | 103 | with TestClient(app) as auth_client: |
92 | 104 | auth_client.headers.update({"Authorization": f"Bearer {access_token}"}) |
93 | 105 | yield auth_client |
0 commit comments