|
3 | 3 | import pytest |
4 | 4 | from fastapi import Depends, FastAPI, Request |
5 | 5 | from fastapi.responses import HTMLResponse |
6 | | -from sqlalchemy.orm import Session |
| 6 | +from sqlalchemy.orm import Session, sessionmaker |
7 | 7 |
|
8 | 8 | from ...testclient import TestClient |
9 | 9 | from .crud import create_user, get_user |
10 | | -from .database import Base, SessionLocal, engine |
| 10 | +from .database import Base, engine |
11 | 11 |
|
12 | 12 | Base.metadata.create_all(bind=engine) |
13 | 13 |
|
14 | 14 |
|
15 | | -def get_db() -> t.Generator: |
16 | | - db = SessionLocal() |
17 | | - try: |
18 | | - yield db |
19 | | - finally: |
20 | | - db.close() |
| 15 | +@pytest.fixture |
| 16 | +def get_db(session_options: dict[str, t.Any]) -> t.Callable: |
| 17 | + SessionLocal = sessionmaker(**session_options or {"bind": engine}) |
| 18 | + |
| 19 | + def f() -> t.Generator: |
| 20 | + db = SessionLocal() |
| 21 | + try: |
| 22 | + yield db |
| 23 | + finally: |
| 24 | + db.close() |
| 25 | + |
| 26 | + return f |
21 | 27 |
|
22 | 28 |
|
23 | 29 | @pytest.fixture |
24 | | -def client(app: FastAPI, get_index: t.Callable) -> TestClient: |
| 30 | +def client(app: FastAPI, get_index: t.Callable, get_db: t.Callable) -> TestClient: |
25 | 31 | @app.get("/sql", response_class=HTMLResponse) |
26 | 32 | async def get_sql(request: Request, db: Session = Depends(get_db)) -> HTMLResponse: |
27 | | - user = create_user(db=db, username="test") |
| 33 | + user = create_user(db=db, username=str(id(get_db))) |
28 | 34 | get_user(db=db, user_id=user.id) |
29 | 35 | get_user(db=db, user_id=user.id) |
30 | 36 | return get_index(request) |
|
0 commit comments