I am writing a custom image data loading function that randomly crops part of a large image according to its binary mask. The function will be used in a PyTorch DataLoader, so I want it to be as fast and memory-efficient as possible. The image and the mask are quite large: both the width and height are on the order of 10k~20k pixels.
I want every crop of the image to contain at least one positive point in the binary mask. My current solution is to first randomly sample one positive point from the mask image, and then generate a crop box around it. The implementation contains a section of code as follows:
import PIL.Image
import numpy as np
... # Some preprocessing to find all mask and image files
mask = PIL.Image.open(mask_file) # both the width and height have 10k~20k pixels.
# fast_pil_to_numpy: https://uploadcare.com/blog/fast-import-of-pillow-images-to-numpy-opencv-arrays/
mask_np = fast_pil_to_numpy(mask).astype(bool) # dim: [height, width]
mask_loc = np.where(mask_np) # get (loc_y, loc_x) of all positive indices
idx = np.random.randint(low=0, high=len(mask_loc[0]))
... # Generate a crop box around (mask_loc[1][idx], mask_loc[0][idx])
After profiling the entire function with line_profiler, I find that the line mask_loc = np.where(mask_np) is one of the performance bottlenecks. How can I optimize this part? Is there a more efficient way to randomly sample one positive point from a binary image?
The best method is to randomly sample a pixel in the mask, and if it’s not set, try again. You can try thousands of times before you get to the cost of enumerating all set pixels.
If your mask is very, very sparse, then your current method is likely best.
If your mask is a small and compact region in the larger image, getting the bounding box and sampling only random pixels in that box would be a speedup.
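A minimal sketch of this rejection-sampling idea, with an arbitrary cap on the number of attempts and a fall-back to full enumeration for very sparse masks (the cap of 10000 is an assumption, not a tuned value):

import numpy as np

def sample_positive(mask_np, rng, max_tries=10_000):
    # Rejection sampling: draw random coordinates until a positive pixel is hit
    h, w = mask_np.shape
    for _ in range(max_tries):
        y, x = rng.integers(h), rng.integers(w)
        if mask_np[y, x]:
            return y, x
    # Mask is very sparse: pay the full enumeration cost once as a fall-back
    ys, xs = np.nonzero(mask_np)
    i = rng.integers(len(ys))
    return ys[i], xs[i]

# y, x = sample_positive(mask_np, np.random.default_rng())

If the mask is a compact region, the same loop can draw coordinates only within its bounding box instead of the full image.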
If your masks do not change over time, i.e. they stay the same across multiple runs of what I assume are the training epochs of an ML model (since you mention PyTorch dataloaders), and if the positive points are rather sparse, you might want to consider storing the indexes of your positive mask points (as lists of 2-tuples, N×2 Numpy arrays, etc., where each 2-tuple/row is the (x, y)-coordinate of one mask point). This takes one expensive preprocessing run, but after that, you can sample from these lists and are guaranteed to get a positive point each time.
Whether sampling the indexes or repeatedly sampling the mask directly (as proposed in Cris Luengo's answer) is faster pretty much depends on how sparse your mask points are.
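For illustration, a rough sketch of what that one-off preprocessing could look like, assuming the coordinate arrays are cached next to the mask files as .npy (the file naming and the mask_files list are assumptions, and plain np.asarray is used instead of fast_pil_to_numpy for brevity):

import numpy as np
import PIL.Image

# One-off preprocessing: cache the coordinates of all positive pixels per mask
for mask_file in mask_files:
    mask_np = np.asarray(PIL.Image.open(mask_file)).astype(bool)
    pos_idxs = np.argwhere(mask_np)  # N×2 array of (y, x) coordinates
    np.save(mask_file + ".pos_idxs.npy", pos_idxs)

# Later, e.g. in the dataset's __getitem__: memory-map the cached indexes and sample directly
pos_idxs = np.load(mask_file + ".pos_idxs.npy", mmap_mode="r")
loc_y, loc_x = pos_idxs[np.random.randint(len(pos_idxs))]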
Here is some code for timing both approaches, trying to sample a single positive mask point:
import matplotlib.pyplot as plt
import numpy as np
from timeit import Timer
from tqdm import tqdm  # To monitor progress

h, w = 1000, 1000
rand = np.random.default_rng(seed=42)
num_timings = 100

def sample_from_mask(mask_):
    while True:
        r, c = rand.integers(h), rand.integers(w)
        if mask_[r, c]:
            return r, c

def sample_from_idxs(mask_idxs_):
    return rand.choice(mask_idxs_)

timings_mask, timings_idxs = [], []
num_positives = np.round(np.logspace(1, np.log10(h * w), num=10)).astype(int)
for n in tqdm(num_positives):
    # Set up the positive indexes and the corresponding mask
    mask_idxs = rand.choice(np.mgrid[0:h, 0:w].reshape(2, -1).T, size=n, replace=False)  # N×2
    mask = np.zeros((h, w))
    mask[tuple(mask_idxs.T)] = 1
    # Time the sampling
    timings_mask.append(Timer(lambda: sample_from_mask(mask)).timeit(num_timings))
    timings_idxs.append(Timer(lambda: sample_from_idxs(mask_idxs)).timeit(num_timings))

plt.loglog()
plt.xlabel("density of positive pixels")
plt.ylabel("time to sample one positive pixel (s)")
density_positives = np.divide(num_positives, h * w)
plt.plot(density_positives, timings_mask, "x", label="sample_from_mask()")
plt.plot(density_positives, timings_idxs, "x", label="sample_from_idxs()")
plt.legend()
And here is the result I get (note that the y-axis label is wrong: since num_timings = 100, the plot shows the time to sample 100 positive pixels, not one):
The crossing point for my implementations is around a density of 28% positive pixels, below which sampling indexes is faster. Note that this doesn’t include the step of getting the mask indexes from the mask (i.e. the preprocessing run that I mention above), which would shift the result in favor of the repeated mask sampling approach.
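If you want to choose between the two approaches at runtime, a cheap density check is one option. A rough sketch, reusing sample_from_mask() and sample_from_idxs() from the timing code above; the 0.28 threshold is just my measured crossover and will differ on other machines, and precomputed_idxs is assumed to be the cached N×2 index array:

density = mask_np.mean()  # fraction of positive pixels; one cheap pass over the mask
if density < 0.28:
    # Sparse mask: sampling from the precomputed indexes wins
    y, x = sample_from_idxs(precomputed_idxs)
else:
    # Dense mask: repeated direct sampling wins
    y, x = sample_from_mask(mask_np)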
For this kind of computation I found NumPy rather slow compared to PyTorch (on CPU): for a 5000×4000 (20M pixels) input image, PyTorch is almost 10 times faster than NumPy on my computer (Apple M1 Pro CPU).
You should be careful though: PyTorch seems to fully utilize the CPU cores while NumPy does not. This means that if you load multiple samples concurrently using a PyTorch DataLoader, you may not see a speed improvement, since multiple worker threads will compete for the same CPU resources. However, this may still overlap some I/O (image loading) with computation, so it may be worth a try.
Here is the code to reproduce the benchmark:
import numpy as np
import torch
from tqdm import tqdm

def main():
    N = 100
    H, W = 4000, 5000

    # NumPy
    mask_np = np.random.randn(H, W) > 0.0
    for _ in tqdm(range(N)):
        _ = np.where(mask_np)

    # PyTorch
    mask_torch = torch.from_numpy(mask_np)
    for _ in tqdm(range(N)):
        _ = torch.where(mask_torch)

    # Check consistency
    y_np, x_np = np.where(mask_np)
    y_torch, x_torch = torch.where(mask_torch)
    assert np.all(y_np == y_torch.numpy())
    assert np.all(x_np == x_torch.numpy())

if __name__ == "__main__":
    main()
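If torch.where turns out to be faster in your setup, here is a small sketch of how it could replace np.where in the original sampling code (just an illustration; the random mask stands in for the real one):

import numpy as np
import torch

mask_np = np.random.randn(1000, 1000) > 0.0  # stand-in for the real mask
mask_torch = torch.from_numpy(mask_np)       # bool tensor, shape [H, W]
loc_y, loc_x = torch.where(mask_torch)       # indices of all positive pixels
idx = torch.randint(len(loc_y), (1,)).item()
# Generate the crop box around (loc_x[idx], loc_y[idx]) as before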
