I am trying to write pytest cases for below spanner DB call
class Database:
_client = None
def __init__(self, instance_id: str, database_id: str, pool=None):
if not Database._client:
Database._client = SpannerClient()
instance = Database._client.instance(instance_id)
self._database = instance.database(database_id, pool=pool)
def execute_query(
self, query: str, params: Dict | None = None, param_types: Dict | None = None
):
try:
with self._database.snapshot() as snapshot:
results = snapshot.execute_sql(query, params, param_types)
df = DataFrame(
data=[row for row in results],
columns=[col.name for col in results.fields],
)
return df
except GoogleAPICallError as e:
print(f"Error code:{e.code},Error message: {e.message}")
raise GoogleAPIError() from e
def get_database(self):
return self._database
Below is the pytest code that is currently written:
import pytest
from unittest.mock import patch, MagicMock
from database import Database
from google.api_core.exceptions import GoogleAPICallError
from pandas import DataFrame
@pytest.fixture
def mock_spanner_client():
with patch("database.Database._client") as MockClient:
yield MockClient
@pytest.fixture
def mock_instance(mock_spanner_client):
mock_instance = MagicMock()
mock_spanner_client.instance.return_value = mock_instance
yield mock_instance
@pytest.fixture
def mock_database(mock_instance):
mock_database = MagicMock()
mock_instance.database.return_value = mock_database
yield mock_database
def test_database_initialization(mock_spanner_client, mock_instance, mock_database):
db = Database("test_instance", "test_database")
assert db._database == mock_database
def test_get_database(mock_database):
db = Database("test_instance", "test_database")
assert db.get_database() == mock_database
def test_execute_query_success(mock_database):
mock_snapshot = MagicMock()
mock_snapshot.execute_sql.return_value = MagicMock(
__iter__=lambda self: iter([["row1"], ["row2"]]),
fields=[MagicMock(name="col1")],
)
mock_database.snapshot.return_valuezz = mock_snapshot
db = Database("test_instance", "test_database")
query = "SELECT * FROM test_table"
result = db.execute_query(query)
assert isinstance(result, DataFrame)
assert not result.empty
assert list(result.columns) == ["col1"]
Need help with mocking the execute sql part where spanner returns a StreamedResultSet Iterator.
mock_snapshot.execute_sql.return_value is giving a empty Dataframe which results in assertion error.