I tried to make asynhronous code and it is working fine (Python 3.12), but I struggle with tests.
Here is function that i want to test:
async def get_repositories(
gl_client: gitlab.client.Gitlab,
marks: vng_release_notes_obj.Marks,
tag_pattern: re.Pattern[str],
repositories_info: Iterable[RepositoryBasicInfo],
) -> Optional[dict[pathlib.PurePosixPath, Repository]]:
repositories: dict[pathlib.PurePosixPath, Repository] = {}
semaphore = asyncio.Semaphore(4)
tasks = [
get_repository(semaphore, gl_client, marks, tag_pattern, repository_info)
for repository_info in repositories_info
]
for task in asyncio.as_completed(tasks):
repository = await task
if repository is None:
return None
repositories[repository.path] = repository
return repositories
And test, there is problem:
- mock properly returns (var: task) from get_repository (Coroutine[Any, Any, Repository | None]), so I can then check if it used as args in .as_completed
- another problem is correctly yield from
await task
to return right variable (repository or None).
class TestGetRepositories(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None:
"""Call before each test in this class."""
self.last_modification = LastModification(
"", datetime.datetime.now(), ""
)
self.gl_client = unittest.mock.AsyncMock()
self.marks = unittest.mock.MagicMock(spec_set=Marks)
self.tag_pattern = re.compile("dd")
self.repository_info_a = RepositoryBasicInfo(
"a", pathlib.PurePosixPath("a")
)
self.repository_info_b = RepositoryBasicInfo(
"b", pathlib.PurePosixPath("b")
)
self.repository_info_c = RepositoryBasicInfo(
"c", pathlib.PurePosixPath("c")
)
self.repository_info = [self.repository_info_a, self.repository_info_b, self.repository_info_c]
class AwaitableMock(unittest.mock.AsyncMock):
def __await__(self) -> typing.Iterator[typing.Any]:
self.await_count += 1
print(self.return_value)
return iter([self.return_value.return_value])
@unittest.mock.patch("asyncio.as_completed")
@unittest.mock.patch("asyncio.Semaphore")
@unittest.mock.patch("get_repository")
async def test_invalid_case(
self,
mock_get_repository: unittest.mock.AsyncMock,
mock_semaphore: unittest.mock.MagicMock,
mock_as_completed: unittest.mock.MagicMock,
) -> None:
"""Test invalid case."""
mock_semaphore_obj = unittest.mock.MagicMock()
mock_semaphore.return_value = mock_semaphore_obj
repository = unittest.mock.AsyncMock(spec=vng_release_notes.repository.Repository)
repository.path = pathlib.PurePosixPath("a")
#Here should return None in second variable
as_compmleted_ret_a = self.AwaitableMock(return_value=repository)
as_compmleted_ret_b = self.AwaitableMock(return_value=None)
as_compmleted_ret_c = self.AwaitableMock()
mock_as_completed.return_value = [as_compmleted_ret_a, as_compmleted_ret_b, as_compmleted_ret_c]
self.assertIsNone(
await vng_release_notes.gl.repository.get_repositories(
self.gl_client, self.marks, self.tag_pattern, self.repository_info
)
)
self.assertEqual(mock_get_repository.call_count, 3)
self.assertEqual(len(mock_get_repository.call_args_list[0].args), 5)
self.assertIs(mock_get_repository.call_args_list[0].args[0], mock_semaphore_obj)
self.assertIs(mock_get_repository.call_args_list[0].args[1], self.gl_client)
self.assertIs(mock_get_repository.call_args_list[0].args[2], self.marks)
self.assertIs(mock_get_repository.call_args_list[0].args[3], self.tag_pattern)
self.assertIs(mock_get_repository.call_args_list[0].args[4], self.repository_info_a)
self.assertEqual(len(mock_get_repository.call_args_list[1].args), 5)
self.assertIs(mock_get_repository.call_args_list[1].args[0], mock_semaphore_obj)
self.assertIs(mock_get_repository.call_args_list[1].args[1], self.gl_client)
self.assertIs(mock_get_repository.call_args_list[1].args[2], self.marks)
self.assertIs(mock_get_repository.call_args_list[1].args[3], self.tag_pattern)
self.assertIs(mock_get_repository.call_args_list[1].args[4], self.repository_info_b)
self.assertEqual(len(mock_get_repository.call_args_list[1].args), 5)
self.assertIs(mock_get_repository.call_args_list[2].args[0], mock_semaphore_obj)
self.assertIs(mock_get_repository.call_args_list[2].args[1], self.gl_client)
self.assertIs(mock_get_repository.call_args_list[2].args[2], self.marks)
self.assertIs(mock_get_repository.call_args_list[2].args[3], self.tag_pattern)
self.assertIs(mock_get_repository.call_args_list[2].args[4], self.repository_info_c)
self.assertEqual(mock_semaphore.call_count, 1)
self.assertIs(mock_semaphore.call_args[0][0], 4)
self.assertEqual(mock_as_completed.call_count, 1)
@unittest.mock.patch("asyncio.as_completed")
@unittest.mock.patch("asyncio.Semaphore")
@unittest.mock.patch(
"get_repository", new_callable=unittest.mock.AsyncMock
)
async def test_valid_case(
self,
mock_get_repository: unittest.mock.AsyncMock,
mock_semaphore: unittest.mock.MagicMock,
mock_as_completed: unittest.mock.MagicMock,
) -> None:
"""Test valid case.
Args:
mock_get_repository: Mock for get_repository function.
mock_semaphore: Mock for asyncio.Semaphore object.
mock_as_completed: Mock for asyncio.as_completed function.
"""
mock_semaphore_obj = unittest.mock.MagicMock()
mock_semaphore.return_value = mock_semaphore_obj
repository_a_path = pathlib.PurePosixPath("a")
repository_a = unittest.mock.MagicMock(spec=vng_release_notes.repository.Repository)
repository_a.path = repository_a_path
repository_b_path = pathlib.PurePosixPath("b")
repository_b = unittest.mock.MagicMock(spec=vng_release_notes.repository.Repository)
repository_b.path = repository_b_path
repository_c_path = pathlib.PurePosixPath("c")
repository_c = unittest.mock.MagicMock(spec=vng_release_notes.repository.Repository)
repository_c.path = repository_c_path
as_compmleted_ret_a = self.AwaitableMock(side_effect=repository_a)
as_compmleted_ret_b = self.AwaitableMock(side_effect=repository_b)
as_compmleted_ret_c = self.AwaitableMock(side_effect=repository_c)
mock_as_completed.return_value = [as_compmleted_ret_a, as_compmleted_ret_b, as_compmleted_ret_c]
output = await vng_release_notes.gl.repository.get_repositories(
self.gl_client, self.marks, self.tag_pattern, self.repository_info
)
if output is None:
self.fail("Should not be None")
self.assertIsInstance(output, dict)
self.assertEqual(len(output), 3)
self.assertIs(output[repository_a_path], repository_a)
self.assertIs(output[repository_b_path], repository_b)
self.assertIs(output[repository_c_path], repository_c)
self.assertEqual(mock_get_repository.call_count, 3)
self.assertEqual(len(mock_get_repository.call_args_list[0].args), 5)
self.assertIs(mock_get_repository.call_args_list[0].args[0], mock_semaphore_obj)
self.assertIs(mock_get_repository.call_args_list[0].args[1], self.gl_client)
self.assertIs(mock_get_repository.call_args_list[0].args[2], self.marks)
self.assertIs(mock_get_repository.call_args_list[0].args[3], self.tag_pattern)
self.assertIs(mock_get_repository.call_args_list[0].args[4], self.repository_info_a)
self.assertEqual(len(mock_get_repository.call_args_list[1].args), 5)
self.assertIs(mock_get_repository.call_args_list[1].args[0], mock_semaphore_obj)
self.assertIs(mock_get_repository.call_args_list[1].args[1], self.gl_client)
self.assertIs(mock_get_repository.call_args_list[1].args[2], self.marks)
self.assertIs(mock_get_repository.call_args_list[1].args[3], self.tag_pattern)
self.assertIs(mock_get_repository.call_args_list[1].args[4], self.repository_info_b)
self.assertEqual(len(mock_get_repository.call_args_list[1].args), 5)
self.assertIs(mock_get_repository.call_args_list[2].args[0], mock_semaphore_obj)
self.assertIs(mock_get_repository.call_args_list[2].args[1], self.gl_client)
self.assertIs(mock_get_repository.call_args_list[2].args[2], self.marks)
self.assertIs(mock_get_repository.call_args_list[2].args[3], self.tag_pattern)
self.assertIs(mock_get_repository.call_args_list[2].args[4], self.repository_info_c)
self.assertEqual(mock_semaphore.call_count, 1)
self.assertIs(mock_semaphore.call_args[0][0], 4)
self.assertEqual(mock_as_completed.call_count, 1)
Thank you for your suggestions.
I checked several other questions on stackoverflow, but it did not help.
New contributor
user25184856 is a new contributor to this site. Take care in asking for clarification, commenting, and answering.
Check out our Code of Conduct.