I want to test my GraphQL endpoint. To do this, I need to mock the database.
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from contextlib import asynccontextmanager
import os
class Database:
    """Create and hold the application's async SQLAlchemy engine.

    ``connect`` must be called once before ``get_async_session`` is used.
    """

    # Fallback URL used when neither an explicit URL nor DATABASE_URL is given.
    # `create_async_engine` requires an async driver: the plain ``sqlite:///``
    # URL selects the sync pysqlite driver and is rejected, so use aiosqlite.
    DEFAULT_URL = "sqlite+aiosqlite:///db.sqlite"

    def __init__(self) -> None:
        # Async engine created by `connect`; None while not connected.
        self.engine = None

    def connect(self, url: "str | None" = None) -> None:
        """Create the async engine.

        Args:
            url: Database URL to connect to. When omitted, the ``DATABASE_URL``
                environment variable is used, falling back to ``DEFAULT_URL``.
                Passing a URL explicitly lets tests point the app at a
                throwaway database without patching module globals.

        Raises:
            ConnectionAbortedError: If the database is already connected.
        """
        if self.engine is not None:
            raise ConnectionAbortedError("Database is already connected")
        if url is None:
            url = os.environ.get("DATABASE_URL", self.DEFAULT_URL)
        self.engine = create_async_engine(url)

    async def disconnect(self) -> None:
        """Dispose of the engine's connection pool and mark the database disconnected.

        Does nothing when not connected. Afterwards ``connect`` may be called
        again (for example with a different URL in tests).
        """
        if self.engine is None:
            return
        # Clear the attribute first so a failing dispose() cannot leave a
        # half-closed engine attached to this instance.
        engine, self.engine = self.engine, None
        await engine.dispose()

    @asynccontextmanager
    async def get_async_session(self):
        """Yield a new ``AsyncSession`` bound to the connected engine.

        The session is closed when the ``async with`` block exits. This is done
        by ``AsyncSession``'s own async context manager, so no extra
        ``close()`` is needed here.

        Raises:
            ConnectionError: If ``connect`` has not been called.
        """
        if self.engine is None:
            raise ConnectionError("Database is not connected")
        async with AsyncSession(self.engine) as session:
            yield session
I’ve tried to mock `create_engine` in my test:
async def test_register_user_mutation(self, gql_schema: Schema, engine):
    """Register-user mutation test (the rest of the test body is not shown here)."""
    # NOTE(review): this patch cannot redirect the app's database.
    # `Database.connect` calls `create_async_engine`, which is imported from
    # `sqlalchemy.ext.asyncio` into the Database module. It never looks up
    # `sqlmodel.create_engine`, so replacing that name has no effect.
    # NOTE(review): `Database.connect()` runs inside `create_app()`, which is
    # called at module import (`app = create_app()`). The engine is therefore
    # likely created before this test body runs anyway.
    with patch("sqlmodel.create_engine") as mock_create_engine:
        mock_create_engine.return_value = engine
and in the `app` fixture:
@pytest.fixture(scope="session", autouse=True)
def app(tmp_path_factory) -> FastAPI:
    """Build the FastAPI app against a throwaway SQLite file instead of production.

    ``Database.connect`` reads ``DATABASE_URL`` when ``create_app()`` runs.
    This fixture therefore sets that variable *before* the app module is
    imported.

    The previous approach (patching ``sqlmodel.create_engine``) had no effect,
    for two reasons:

    * ``Database`` calls ``create_async_engine``, which is imported from
      ``sqlalchemy.ext.asyncio`` into its own module.
    * ``AsyncSession`` expects an async engine, not a sync one.

    NOTE(review): this only works if ``daytistics.main`` has not already been
    imported (e.g. at the top of a test module or conftest). The import runs
    ``create_app()`` once and is then cached. Confirm the import order.
    NOTE(review): this relies on ``dotenv.load_dotenv()`` not overriding
    variables that are already set, which is python-dotenv's default.
    """
    import os

    db_path = tmp_path_factory.mktemp("db") / "test.sqlite"

    # Create the schema through a sync engine on the same file that the app's
    # async (aiosqlite) engine will open. A file is used rather than
    # ":memory:" because an in-memory database is private to its engine.
    schema_engine = create_engine(f"sqlite:///{db_path}")
    SQLModel.metadata.create_all(schema_engine)
    schema_engine.dispose()

    os.environ["DATABASE_URL"] = f"sqlite+aiosqlite:///{db_path}"
    from daytistics.main import app

    return app
and I’ve tried to mock the `Session`:
async def test_register_user_mutation(self, gql_schema: Schema, engine):
    """Register-user mutation test (the rest of the test body is not shown here)."""
    # NOTE(review): this patch does not reach the app's sessions.
    # `Database.get_async_session` builds an `AsyncSession`, imported from
    # `sqlmodel.ext.asyncio.session`, and never looks up `sqlmodel.Session`.
    # Patching that name therefore changes nothing the app uses.
    # NOTE(review): `Session(engine)` is a sync session. If the resolvers
    # `await` session calls (presumably, given `AsyncSession`), it could not
    # stand in for one — verify.
    with patch("sqlmodel.Session", autospec=True) as mock_session_class:
        mock_session_class.return_value = Session(engine)
but every time, the data was written to the production database.
I don’t know if it’s useful, but this is my app factory:
def create_app() -> FastAPI:
    """Assemble the FastAPI application.

    The steps, in order:

    1. Load ``.env`` settings.
    2. Call ``register_dependencies()``.
    3. Connect the ``Database`` resolved from the container.
    4. Mount the GraphQL router under ``/graphql``.
    """
    dotenv.load_dotenv()
    register_dependencies()

    with container.sync_context() as ctx:
        # The engine is created here, i.e. every time the app is built.
        ctx.resolve(Database).connect()

    application = FastAPI()
    application.include_router(create_graphql_router(), prefix="/graphql")
    return application


# Built at import time: importing this module connects the database.
app = create_app()
Update:
I’ve also tried using `monkeypatch`, but it did not work either:
@pytest.fixture(scope="function")
async def mock_db(monkeypatch):
    """Point every ``Database`` at a fresh in-memory SQLite engine for one test.

    ``Database.connect`` and ``Database.get_async_session`` are replaced on the
    class. The replacements therefore also apply to instances that already
    exist, unless an instance shadows them with its own attribute. The fixture
    yields the test engine and disposes of it once the test finishes.

    NOTE(review): with pytest-asyncio in "strict" mode, an ``async def``
    fixture declared with plain ``@pytest.fixture`` is not awaited, so none of
    the patches below would be applied. Use ``@pytest_asyncio.fixture`` or
    ``asyncio_mode = "auto"`` — TODO confirm the project's configuration.
    NOTE(review): the patches are only active in tests that request ``mock_db``.
    """
    from sqlalchemy.pool import StaticPool

    test_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        echo=False,
        future=True,
        connect_args={"check_same_thread": False},
        # Every session must share the single connection that holds the
        # in-memory database; otherwise the tables created below are not visible.
        poolclass=StaticPool,
    )
    async with test_engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)

    def mock_connect(self, *args, **kwargs):
        # Accept (and ignore) any URL argument so this stays a drop-in
        # replacement whatever the real `connect` signature is.
        self.engine = test_engine

    @asynccontextmanager
    async def mock_get_async_session(self):
        # AsyncSession's async context manager closes the session on exit.
        async with AsyncSession(test_engine) as session:
            yield session

    monkeypatch.setattr(Database, "connect", mock_connect)
    monkeypatch.setattr(Database, "get_async_session", mock_get_async_session)
    try:
        yield test_engine
    finally:
        # Release the engine's pool after each test instead of leaking it.
        await test_engine.dispose()
1