Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .env
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ POSTGRES_PORT=5432
POSTGRES_DB=devdb
POSTGRES_USER=devdb
POSTGRES_TEST_DB=testdb
POSTGRES_TEST_USER=testdb
POSTGRES_PASSWORD=secret

# Redis
Expand Down
1 change: 0 additions & 1 deletion app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class Settings(BaseSettings):
POSTGRES_PASSWORD: str
POSTGRES_HOST: str
POSTGRES_DB: str
POSTGRES_TEST_USER: str
POSTGRES_TEST_DB: str

@computed_field
Expand Down
2 changes: 1 addition & 1 deletion app/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
test_engine = create_async_engine(
global_settings.test_asyncpg_url.unicode_string(),
future=True,
echo=True,
echo=False,
)

# expire_on_commit=False will prevent attributes from being expired
Expand Down
19 changes: 13 additions & 6 deletions tests/api/test_stuff.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from httpx import AsyncClient
from inline_snapshot import snapshot
from polyfactory.factories.pydantic_factory import ModelFactory
from sqlalchemy.ext.asyncio import AsyncSession

from app.schemas.stuff import StuffSchema
from app.models import Stuff

pytestmark = pytest.mark.anyio

Expand All @@ -14,7 +16,7 @@ class StuffFactory(ModelFactory[StuffSchema]):
__model__ = StuffSchema


async def test_add_stuff(client: AsyncClient):
async def test_add_stuff(client: AsyncClient, db_session: AsyncSession):
stuff = StuffFactory.build(factory_use_constructors=True).model_dump(mode="json")
response = await client.post("/stuff", json=stuff)
assert response.status_code == status.HTTP_201_CREATED
Expand All @@ -32,22 +34,27 @@ async def test_add_stuff(client: AsyncClient):
)


async def test_get_stuff(client: AsyncClient):
async def test_get_stuff(client: AsyncClient, db_session: AsyncSession):
response = await client.get("/stuff/nonexistent")
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json() == snapshot(
{"no_response": "The requested resource was not found"}
)
stuff = StuffFactory.build(factory_use_constructors=True).model_dump(mode="json")
await client.post("/stuff", json=stuff)
name = stuff["name"]
# await client.post("/stuff", json=stuff)
# name = stuff["name"]
Comment thread
grillazz marked this conversation as resolved.
Outdated
stuff = Stuff(**stuff)
name = stuff.name
db_session.add(stuff)
await db_session.commit()

response = await client.get(f"/stuff/{name}")
assert response.status_code == status.HTTP_200_OK
assert response.json() == snapshot(
{
"id": IsUUID(4),
"name": stuff["name"],
"description": stuff["description"],
"name": stuff.name,
"description": stuff.description,
}
)

Expand Down
34 changes: 28 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from typing import Any

import pytest
from fastapi.exceptions import ResponseValidationError
Comment thread
grillazz marked this conversation as resolved.
Outdated
from httpx import ASGITransport, AsyncClient
from sqlalchemy import text
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
Comment thread
grillazz marked this conversation as resolved.
Outdated

from app.database import engine, get_db, get_test_db, test_engine
from app.database import engine, get_db, test_engine, TestAsyncSessionFactory
from app.main import app
from app.models.base import Base
from app.redis import get_redis
Expand Down Expand Up @@ -43,7 +44,7 @@ def _create_db_schema(conn) -> None:
pass


@pytest.fixture(scope="session")
@pytest.fixture(scope="session", autouse=True)
async def start_db():
# The `engine` is configured for the default 'postgres' database.
# We connect to it and create the test database.
Expand All @@ -63,16 +64,37 @@ async def start_db():
await test_engine.dispose()


@pytest.fixture(scope="session")
async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001
@pytest.fixture()
async def db_session():
connection = await test_engine.connect()
transaction = await connection.begin()
session = TestAsyncSessionFactory(bind=connection)

try:
yield session
finally:
# Rollback the overall transaction, restoring the state before the test ran.
await session.close()
if transaction.is_active:
await transaction.rollback()
await connection.close()


@pytest.fixture(scope="function")
async def client(db_session) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001
Comment thread
grillazz marked this conversation as resolved.
Outdated
transport = ASGITransport(
app=app,
)

async def override_get_db():
yield db_session
await db_session.commit()
Comment thread
grillazz marked this conversation as resolved.
Outdated

async with AsyncClient(
base_url="http://testserver/v1",
headers={"Content-Type": "application/json"},
transport=transport,
) as test_client:
app.dependency_overrides[get_db] = get_test_db
app.dependency_overrides[get_db] = override_get_db
app.redis = await get_redis()
yield test_client
Loading