I’m trying to run a test suite on FastAPI endpoints that use a database connection. I’m using FastAPI’s dependency injection on the endpoint, which I attempt to override with a test database during testing.
# app/dependencies.py
from typing import Annotated, AsyncGenerator
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from fastapi import Depends
from app.settings import settings
async def get_db() -> AsyncGenerator[AsyncSession, None]:
db_config = ConnectionConfig(
service="mysql",
driver="asyncmy",
user=settings.MYSQL_USER,
password=settings.MYSQL_PASSWORD,
host=settings.MYSQL_HOST,
database=settings.MYSQL_DATABASE,
)
engine = create_async_engine(
get_db_url(config=db_config), echo=settings.DEVELOPMENT
)
async with AsyncSession(engine) as session:
yield session
# test_db.py
from typing import AsyncGenerator, Iterator
from contextlib import asynccontextmanager
import pytest
from fastapi import status
from fastapi.testclient import TestClient
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from app.dependencies import get_db
@pytest.fixture
async def session() -> AsyncGenerator[AsyncSession, None]:
engine = create_async_engine("sqlite+aiosqlite://")
async with engine.connect() as cxn, cxn.begin() as txn:
await cxn.run_sync(SQLModel.metadata.create_all)
async_session = AsyncSession(
bind=cxn, join_transaction_mode="create_savepoint"
)
yield async_session
await txn.rollback()
await engine.dispose()
@pytest.fixture
def client(session: AsyncSession):
def get_session_override():
return session
app = create_app()
app.dependency_overrides[get_db] = get_session_override
with TestClient(app) as c:
yield c
app.dependency_overrides.clear()
@pytest.mark.anyio
def test_create_attestation(
attestations: Iterator[Attestation], client: TestClient
):
for attestation in attestations:
payload = {"topic": "some topic"}
resp = client.post(url="/v1/upload", json=payload)
assert resp.status_code == status.HTTP_200_OK
# app/routers.py
@router.post(
path="/upload",
)
async def upload(
req: UploadRequest, db: AsyncSession = Depends(get_db)
) -> AttestationResponse:
row = Row(**req.model_dump())
db.add(row)
await db.commit()
When running the endpoints directly, the dependency injection works, but in the test suite, I get the following error:
AttributeError: 'async_generator' object has no attribute 'add'
I’ve tried to modify the fixture scopes, as well as trying to iterate through the async_generator using __anext__
, but I’m still getting the same error. Any help would be appreciated.