I am working on a part of a project that partitions a dataset and distributes each chunk, together with a global machine learning model, to each worker node in a distributed training environment. I am using PyTorch. When I run the unit test below, it fails with the error shown at the end, and I have not been able to figure out why. Can anyone help me?
data_partitioning.py
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn import Module
import torch.distributed as dist
import io
def partition_dataset(dataset: Dataset, num_workers: int) -> list[Dataset]:
    """
    Parameters:
    - dataset: The dataset to be partitioned.
    - num_workers: The number of partitions to create, typically the number of workers.
    Returns:
    - List of datasets, one for each worker.
    """
    if len(dataset) == 0:
        raise ValueError("The dataset is empty")
    if num_workers <= 0:
        raise ValueError("Invalid number of workers")
    partition_size = len(dataset) // num_workers
    partitions = [partition_size] * num_workers
    remainder = len(dataset) - partition_size * num_workers
    if remainder > 0:
        partitions[-1] += remainder
    return random_split(dataset, partitions)

def master_node_distribute(backend: str, init_method: str, model: Module, dataset: Dataset, world_size: int):
    """
    Master node function to distribute the model and dataset partitions to worker nodes.
    Parameters:
    - backend and init_method: Parameters for init_process_group
    - model: The global model to be trained.
    - dataset: The full dataset to be partitioned.
    - world_size: The total number of workers.
    """
    dist.init_process_group(backend=backend, init_method=init_method, rank=0, world_size=world_size)
    # partition data
    partitions = partition_dataset(dataset, world_size)
    # serialize model
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    buffer.seek(0)
    model_tensor = torch.tensor(bytearray(buffer.read()), dtype=torch.uint8)
    for rank in range(1, world_size):
        # send data
        data_tensor = torch.tensor(partitions[rank], dtype=torch.float32)
        dist.send(tensor=data_tensor, dst=rank)
        # send model
        dist.send(tensor=model_tensor, dst=rank)
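
For context, the worker-side counterpart I have in mind looks roughly like the sketch below. It is not part of the failing test, and it assumes each worker already knows the size of its data chunk (chunk_size) and the byte length of the serialized model (model_num_bytes), since dist.recv needs a pre-allocated destination tensor; both parameter names are just placeholders for illustration.

worker_node.py (sketch only)
import io
import torch
import torch.distributed as dist
from torch.nn import Module

def worker_node_receive(backend: str, init_method: str, rank: int, world_size: int,
                        chunk_size: int, model_num_bytes: int, model: Module):
    dist.init_process_group(backend=backend, init_method=init_method, rank=rank, world_size=world_size)
    # receive this worker's data chunk (sent as float32 by the master)
    data_tensor = torch.empty(chunk_size, dtype=torch.float32)
    dist.recv(tensor=data_tensor, src=0)
    # receive the serialized state_dict as raw bytes and load it into the local model copy
    model_tensor = torch.empty(model_num_bytes, dtype=torch.uint8)
    dist.recv(tensor=model_tensor, src=0)
    buffer = io.BytesIO(model_tensor.numpy().tobytes())
    model.load_state_dict(torch.load(buffer))
    return data_tensor, model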
data_partitioning_unit_test.py
import unittest
from unittest.mock import patch, Mock, call
import torch
from data_partitioning import master_node_distribute
from torch.utils.data import Dataset
from torch.nn import Module

class MockDataset(Dataset):
    def __init__(self, size):
        self.data = list(range(size))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

class MockModel(Module):
    def __init__(self):
        super(MockModel, self).__init__()
        self.layer = torch.nn.Linear(10, 1)

class TestDataPartitioning(unittest.TestCase):
    @patch('data_partitioning.random_split')
    @patch('data_partitioning.dist')
    @patch('data_partitioning.io.BytesIO')
    @patch('data_partitioning.torch')
    @patch('torch.distributed.init_process_group')
    def test_master_node_distribute(self, mock_init_process_group, mock_torch, mock_bytes_io, mock_dist, mock_random_split):
        mock_init_process_group.return_value = None
        model = MockModel()
        dataset = MockDataset(100)
        world_size = 4
        backend = 'gloo'
        init_method = ""
        # Mock partitions
        partitions = [MockDataset(25) for _ in range(world_size)]
        # Mock buffer
        buffer = Mock()
        mock_bytes_io.return_value = buffer
        model_state = Mock()
        mock_torch.save.return_value = model_state
        buffer.read.return_value = b'model_state_data'

        with patch('data_partitioning.random_split', return_value=partitions):
            master_node_distribute(backend, init_method, model, dataset, world_size)

        # Check that data and model were sent to each worker
        expected_data_calls = [
            call(torch.tensor(partitions[i], dtype=torch.float32), dst=i)
            for i in range(1, world_size)
        ]
        expected_model_call = [
            call(torch.tensor(bytearray(buffer.read()), dtype=torch.uint8), dst=i)
            for i in range(1, world_size)
        ]
        expected_calls = expected_data_calls + expected_model_call * (world_size - 1)
        mock_dist.send.assert_has_calls(expected_calls, any_order=True)


if __name__ == '__main__':
    unittest.main()
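
For reference, partition_dataset on its own should be straightforward: with 100 samples and 4 workers it splits the data into four chunks of 25. A minimal direct call (outside the unit test, using TensorDataset purely as a stand-in dataset) would look like this:

from torch.utils.data import TensorDataset
import torch
from data_partitioning import partition_dataset

dataset = TensorDataset(torch.arange(100, dtype=torch.float32))
parts = partition_dataset(dataset, 4)
print([len(p) for p in parts])  # expected: [25, 25, 25, 25]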
Error when running the unit test:
raise AssertionError(
AssertionError: 'send' does not contain all of (call(tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13.,
14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.]), dst=1), call(tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13.,
14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.]), dst=2), call(tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13.,
14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.]), dst=3), call(tensor([109, 111, 100, 101, 108, 95, 115, 116, 97, 116, 101, 95, 100, 97,
116, 97], dtype=torch.uint8), dst=1), call(tensor([109, 111, 100, 101, 108, 95, 115, 116, 97, 116, 101, 95, 100, 97,
116, 97], dtype=torch.uint8), dst=2), call(tensor([109, 111, 100, 101, 108, 95, 115, 116, 97, 116, 101, 95, 100, 97,
116, 97], dtype=torch.uint8), dst=3), call(tensor([109, 111, 100, 101, 108, 95, 115, 116, 97, 116, 101, 95, 100, 97,
116, 97], dtype=torch.uint8), dst=1), call(tensor([109, 111, 100, 101, 108, 95, 115, 116, 97, 116, 101, 95, 100, 97,
116, 97], dtype=torch.uint8), dst=2), call(tensor([109, 111, 100, 101, 108, 95, 115, 116, 97, 116, 101, 95, 100, 97,
116, 97], dtype=torch.uint8), dst=3), call(tensor([109, 111, 100, 101, 108, 95, 115, 116, 97, 116, 101, 95, 100, 97,
116, 97], dtype=torch.uint8), dst=1), call(tensor([109, 111, 100, 101, 108, 95, 115, 116, 97, 116, 101, 95, 100, 97,
116, 97], dtype=torch.uint8), dst=2), call(tensor([109, 111, 100, 101, 108, 95, 115, 116, 97, 116, 101, 95, 100, 97,
116, 97], dtype=torch.uint8), dst=3)) in its call list, found [call(tensor=<MagicMock name='torch.tensor()' id='4696420848'>, dst=1), call(tensor=<MagicMock name='torch.tensor()' id='4696420848'>, dst=1), call(tensor=<MagicMock name='torch.tensor()' id='4696420848'>, dst=2), call(tensor=<MagicMock name='torch.tensor()' id='4696420848'>, dst=2), call(tensor=<MagicMock name='torch.tensor()' id='4696420848'>, dst=3), call(tensor=<MagicMock name='torch.tensor()' id='4696420848'>, dst=3)] instead
----------------------------------------------------------------------
Ran 1 test in 0.006s