I’m using PyTorch’s DataLoader to load my dataset. I’ve noticed that my program hangs indefinitely during training when I set num_workers > 0. However, it works fine when num_workers = 0.
Here’s a simplified version of my code:
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.transforms import InterpolationMode


class MedianFilter:
    # Custom transform, reduced to an identity function for this example
    def __init__(self, kernel_size=3):
        self.kernel_size = kernel_size

    def __call__(self, img):
        return img


train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    MedianFilter(),
    transforms.RandomAffine(degrees=40, translate=(0.125, 0.125)),
    transforms.RandomResizedCrop(size=(28, 28), scale=(1, 1), ratio=(1, 1), interpolation=InterpolationMode.BILINEAR),
    transforms.ToTensor()
])

val_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    MedianFilter(),
    transforms.ToTensor()
])

train_dataset = ImageFolder(root='../Dataset/Original/train/', transform=train_transform)
val_dataset = ImageFolder(root='C:../Dataset/Original/val/', transform=val_transform)

train_dataloader = DataLoader(train_dataset, batch_size=64, pin_memory=True, num_workers=3, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64, pin_memory=True, num_workers=3)
dataloader = {'train': train_dataloader, 'val': val_dataloader}


def train_model(model, dataloader, criterion, optimizer, scheduler, num_epochs):
    acc_history = {'train': [], 'val': []}
    loss_history = {'train': [], 'val': []}
    best_acc = 0.0
    for epoch in range(1, num_epochs + 1):
        print(f"Epoch{epoch}:")
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_correct = 0
            running_loss = 0.0
            totalIm = 0
            for data, _label in dataloader[phase]:  # Stuck here
                ...  # remainder of the loop omitted in this simplified version
In this code, the MedianFilter class is a simple identity function. Despite this, the program still hangs when num_workers > 0.
Why is this happening and how can I resolve this issue?
I have tried simplifying the MedianFilter class to an identity function that just returns the input image. I expected this to resolve the issue, since the class no longer does any significant computation, but the program still hangs when num_workers > 0.
I have also tried running the code without the custom transform (MedianFilter) while keeping num_workers > 0. In this case, the code runs as expected.
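Since the hang only appears with the custom class and with worker processes, one thing that may be worth ruling out is how the class reaches the workers. The following is a minimal sketch, assuming the script runs on a platform where worker processes are spawned rather than forked (the `C:` path hints at Windows, but that is an assumption). In that setup each worker re-imports the main file and receives the dataset, including its transforms, through pickle. The helper `can_pickle` is a hypothetical name introduced only for this illustration, and the sketch uses the standard library only, so it does not exercise DataLoader itself.

```python
import pickle


class MedianFilter:
    """Identity transform from the question, defined at module top level."""

    def __init__(self, kernel_size=3):
        self.kernel_size = kernel_size

    def __call__(self, img):
        return img


def can_pickle(obj):
    """Hypothetical helper: True if obj pickles and unpickles in this process."""
    try:
        pickle.loads(pickle.dumps(obj))
    except Exception:
        return False
    return True


if __name__ == "__main__":
    # The guard matters when workers are spawned: each worker re-imports
    # this file, so DataLoader creation and the training loop belong here
    # (or in functions called from here), not at module top level.
    print("MedianFilter picklable:", can_pickle(MedianFilter()))
    print("lambda picklable:", can_pickle(lambda img: img))
```

This check only covers the parent-process side. A class defined interactively (for example in a notebook cell) can pickle without error in the parent and still fail to be found when a spawned worker tries to restore it, so if that applies here, moving MedianFilter into its own .py file and importing it would be the next thing to try.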