I have an ordered dataset (shuffle=False) that is categorised into “bins”. Here is a small-scale example to clarify: say the dataset size is 60, with bins of sizes 10, 20, and 30. I want to train my model in the order of the bins (first 10, then 20, then 30), with the DataLoader fetching data in batch sizes of 8. In this case, after getting the first 8 datapoints, I don’t want the remaining 2 from bin 1 padded out with 6 from the next one; I want a batch of only those 2, and then, in the next iteration, 8 from bin 2. In short, I want to finish training on one bin before moving to the next. Also, if the batch size happens to be greater than the bin size, I want each batch to draw from a single bin only.
Can I please get some advice on how to do this? I can think of two ways: implementing a custom DataLoader (I’d need advice on this too), or creating a separate DataLoader for each bin and, while iterating with bins in the outermost loop, grabbing the corresponding DataLoader for training. Would the latter method have any serious downsides?
-
The second option is definitely easier. The only problem I can see is if you use several workers, i.e. multiprocessing: each of those DataLoaders will spawn its own worker processes, so you will end up with a bunch of processes, most of them idle (those belonging to the DataLoaders you are not currently using). If you have a lot of these DataLoaders, that might become a performance problem.
If you are not planning on using multiprocessing, though, I would probably go with this one; a minimal sketch follows.
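A rough sketch of that second option, assuming a map-style dataset laid out bin by bin (your_dataset and the training step are placeholders):

from torch.utils.data import DataLoader, Subset

bin_sizes = [10, 20, 30]  # as in the example above

# One DataLoader per bin, built by slicing the ordered dataset with Subset.
loaders = []
start = 0
for size in bin_sizes:
    indices = range(start, start + size)
    loaders.append(DataLoader(Subset(your_dataset, indices), batch_size=8))
    start += size

# Bins in the outermost loop; each inner loop exhausts one bin before
# moving on, so no batch ever mixes bins.
for loader in loaders:
    for batch in loader:
        ...  # your training step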
-
The first option is going to be much harder: you would want to subclass the DataLoader itself (its source code is the place to start), but it gets pretty technical.
-
There is also a third option, which is much simpler, although somewhat hacky: make your Dataset class handle the batching instead of relying on the DataLoader’s collate machinery. Since your dataset is ordered, just make it iterate over whole batches, and set batch_size=None in the DataLoader to disable automatic batching. A minimal sketch follows.
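One way to sketch that third option, assuming the underlying data is an ordered, sliceable sequence (your_data is a placeholder):

from torch.utils.data import DataLoader, IterableDataset


class BinBatchDataset(IterableDataset):
    # Yields whole batches, never crossing a bin boundary. Assumes `data`
    # is an ordered, sliceable sequence laid out bin by bin.
    def __init__(self, data, bin_sizes, batch_size):
        self.data = data
        self.bin_sizes = bin_sizes
        self.batch_size = batch_size

    def __iter__(self):
        start = 0
        for size in self.bin_sizes:
            # Step through the current bin batch_size items at a time,
            # truncating the last batch at the bin boundary.
            for i in range(start, start + size, self.batch_size):
                yield self.data[i:min(i + self.batch_size, start + size)]
            start += size


# batch_size=None disables the DataLoader's automatic batching, so every
# item it returns is already one of the batches yielded above.
loader = DataLoader(
    BinBatchDataset(your_data, [10, 20, 30], batch_size=8),
    batch_size=None,
)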
-
Finally, a fourth option, which is the cleanest: use the batch_sampler capability of the torch DataLoader (see the documentation). Just build a Sampler that yields the indices of each batch; here is a version that follows your example:
import math
from copy import copy
from typing import Iterator, List

from torch.utils.data import DataLoader, Sampler


class BinSampler(Sampler[List[int]]):
    bin_sizes: List[int]
    batch_size: int

    def __init__(self, bin_sizes: List[int], batch_size: int):
        self.bin_sizes = bin_sizes
        self.batch_size = batch_size

    def __len__(self) -> int:
        # A batch sampler's length is the number of *batches*:
        # each bin contributes ceil(bin_size / batch_size) of them.
        return sum(math.ceil(size / self.batch_size) for size in self.bin_sizes)

    def __iter__(self) -> Iterator[List[int]]:
        bin_sizes = copy(self.bin_sizes)
        remaining_in_bin = bin_sizes.pop(0)  # pop from the front to keep bin order
        current_index = 0
        while bin_sizes or remaining_in_bin:
            if remaining_in_bin == 0:
                remaining_in_bin = bin_sizes.pop(0)
            # Never take more indices than are left in the current bin.
            n_to_yield = min(remaining_in_bin, self.batch_size)
            next_index = current_index + n_to_yield
            yield list(range(current_index, next_index))
            remaining_in_bin -= n_to_yield
            current_index = next_index


dataloader = DataLoader(
    dataset=your_dataset,
    batch_sampler=BinSampler([10, 20, 30], batch_size=8),
)
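With bin sizes [10, 20, 30] and batch_size=8, this sampler yields nine batches of sizes 8, 2, 8, 8, 4, 8, 8, 8, 6: each bin is exhausted before the next one starts, which is exactly the behaviour you described.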