Skip to content

Commit adecd85

Browse files
committed
rollback
1 parent ef6f9bc commit adecd85

2 files changed

Lines changed: 12 additions & 51 deletions

File tree

tests/api/test_auth.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,28 +37,17 @@ async def test_add_user(client: AsyncClient):
3737

3838
# TODO: parametrize test with diff urls including 404 and 401
3939
async def test_get_token(client: AsyncClient):
40-
# First, create the user required for this test
41-
user_payload = {
42-
"email": "joe@grillazz.com",
43-
"first_name": "Joe",
44-
"last_name": "Garcia",
45-
"password": "s1lly",
46-
}
47-
create_user_response = await client.post("/user/", json=user_payload)
48-
assert create_user_response.status_code == status.HTTP_201_CREATED
49-
50-
# Now, request the token for the newly created user
51-
token_payload = {"email": "joe@grillazz.com", "password": "s1lly"}
40+
payload = {"email": "joe@grillazz.com", "password": "s1lly"}
5241
response = await client.post(
5342
"/user/token",
54-
data=token_payload,
43+
data=payload,
5544
headers={"Content-Type": "application/x-www-form-urlencoded"},
5645
)
5746
assert response.status_code == status.HTTP_201_CREATED
5847
claimset = jwt.decode(
5948
response.json()["access_token"], options={"verify_signature": False}
6049
)
61-
assert claimset["email"] == token_payload["email"]
50+
assert claimset["email"] == payload["email"]
6251
assert claimset["expiry"] == IsPositiveFloat()
6352
assert claimset["platform"] == "python-httpx/0.28.1"
6453

tests/conftest.py

Lines changed: 9 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from collections.abc import AsyncGenerator
2+
from types import SimpleNamespace
23
from typing import Any
34

45
import pytest
56
from httpx import ASGITransport, AsyncClient
67
from sqlalchemy import text
78
from sqlalchemy.exc import ProgrammingError
8-
from sqlalchemy.ext.asyncio import AsyncSession
99

10-
from app.database import engine, get_db, test_engine
10+
from app.database import engine, test_engine, get_test_db, get_db
1111
from app.main import app
1212
from app.models.base import Base
1313
from app.redis import get_redis
@@ -22,7 +22,6 @@
2222
def anyio_backend(request):
2323
return request.param
2424

25-
2625
def _create_db(conn) -> None:
2726
"""Create the test database if it doesn't exist."""
2827
try:
@@ -65,43 +64,16 @@ async def start_db():
6564
await test_engine.dispose()
6665

6766

68-
@pytest.fixture(scope="function")
69-
async def db_session() -> AsyncGenerator[AsyncSession, Any]:
70-
"""
71-
Provide a transactional database session for each test function.
72-
Rolls back changes after the test.
73-
"""
74-
connection = await test_engine.connect()
75-
transaction = await connection.begin()
76-
session = AsyncSession(bind=connection)
77-
78-
yield session
79-
80-
await session.close()
81-
await transaction.rollback()
82-
await connection.close()
83-
84-
85-
@pytest.fixture(scope="function")
86-
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, Any]:
87-
"""
88-
Provide a test client for making API requests.
89-
Uses the function-scoped db_session for test isolation.
90-
"""
91-
92-
def get_test_db_override():
93-
yield db_session
94-
95-
app.dependency_overrides[get_db] = get_test_db_override
96-
app.redis = await get_redis()
97-
98-
transport = ASGITransport(app=app)
67+
@pytest.fixture(scope="session")
68+
async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001
69+
transport = ASGITransport(
70+
app=app,
71+
)
9972
async with AsyncClient(
10073
base_url="http://testserver/v1",
10174
headers={"Content-Type": "application/json"},
10275
transport=transport,
10376
) as test_client:
77+
app.dependency_overrides[get_db] = get_test_db
78+
app.redis = await get_redis()
10479
yield test_client
105-
106-
# Clean up dependency overrides
107-
del app.dependency_overrides[get_db]

0 commit comments

Comments
 (0)