I have code that calculates the mean of each channel of an image.
The mean is a robust mean: it is calculated only over values between two quantiles, and it ignores zero pixels (pixels where all channel values are 0):

import numpy as np
import skimage as ski
from numba import njit

image = ski.io.imread("https://raw.githubusercontent.com/mikolalysenko/lena/master/lena.png")

def cal_img_stat(img: np.ndarray, low_quantile: float = 0.0, high_quantile: float = 1.0):
    num_px_channel = img.shape[0] * img.shape[1]
    chnl_matrix = np.reshape(np.transpose(img, (2, 0, 1)), (3, num_px_channel))  # bottleneck
    non_zero_px_flag = np.any(chnl_matrix, axis=0)  # ignore zero valued pixel
    non_zero_px = chnl_matrix[:, non_zero_px_flag]
    min_vals, max_vals = np.quantile(non_zero_px, [low_quantile, high_quantile], axis=1)
    mask_px = np.logical_and(non_zero_px >= min_vals[:, None], non_zero_px <= max_vals[:, None])
    mean_vals = np.mean(non_zero_px, axis=1, where=mask_px)
    return mean_vals

mean_vals = cal_img_stat(image, 0.05, 0.95)

I tried to optimize the run time of the code by:
- Applying the function along rows.
- Packing the channels to be contiguous in memory.
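
The channel-packing step from the list above can be checked with a small script. This is only a minimal sketch on a synthetic image (the array `img` and the names `chnl_view` and `chnl_packed` are made up for illustration). As far as I can tell, NumPy can express the transpose plus reshape as a strided view of the original pixels rather than a packed copy, so the flags printed below are the way to confirm what happens on a given NumPy version:

```python
import numpy as np

# Synthetic stand-in for an H x W x 3 uint8 image (assumption, not the real input).
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(64, 48, 3), dtype=np.uint8)

# Same conversion as in cal_img_stat: channels become rows.
chnl_view = np.reshape(np.transpose(img, (2, 0, 1)), (3, -1))

# Explicitly packed copy: each channel row is contiguous in memory.
chnl_packed = np.ascontiguousarray(chnl_view)

# Report the memory layout of both arrays.
print("view   C-contiguous:", chnl_view.flags["C_CONTIGUOUS"], "strides:", chnl_view.strides)
print("packed C-contiguous:", chnl_packed.flags["C_CONTIGUOUS"], "strides:", chnl_packed.strides)
print("view may share memory with img:", np.may_share_memory(img, chnl_view))
```

If the view is not C-contiguous, the reshape itself is cheap and the cost moves to the first operation that has to gather each channel with a stride of 3 bytes, such as the boolean indexing or the quantile.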
In addition, I tried computing the mean itself with a hand-crafted numba function.
I also tried rewriting chnl_matrix = np.reshape(np.transpose(img, (2, 0, 1)), (3, num_px_channel)) in numba,
yet with no improvement.
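
For readers who want something concrete, here is a minimal sketch of what such a hand-written numba loop could look like. It is illustrative only and not necessarily the code that was timed: the function name `masked_channel_means`, the pixel-major `(num_px, 3)` input, and the synthetic image are all assumptions. The quantiles are still computed with NumPy; only the masked mean is done in the loop. The `try`/`except` fallback is there just so the sketch also runs (slowly) without numba installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs without numba (plain Python loop, slow).
    def njit(func):
        return func

@njit
def masked_channel_means(px, lo, hi):
    # px: (num_px, num_ch) pixel array, lo / hi: per-channel bounds.
    num_px, num_ch = px.shape
    sums = np.zeros(num_ch)
    counts = np.zeros(num_ch)
    for i in range(num_px):
        # Skip pixels whose channels are all zero.
        is_zero_px = True
        for c in range(num_ch):
            if px[i, c] != 0:
                is_zero_px = False
        if is_zero_px:
            continue
        # Accumulate only values inside the per-channel quantile range.
        for c in range(num_ch):
            v = px[i, c]
            if v >= lo[c] and v <= hi[c]:
                sums[c] += v
                counts[c] += 1.0
    return sums / counts

# Synthetic stand-in image with a band of all-zero pixels (assumption).
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(64, 48, 3), dtype=np.uint8)
img[:5] = 0

px = np.reshape(img, (-1, 3))  # pixel-major layout, one row per pixel
non_zero_px = px[np.any(px, axis=1)]
lo, hi = np.quantile(non_zero_px, [0.05, 0.95], axis=0)
loop_means = masked_channel_means(px, lo, hi)
print(loop_means)
```

The loop makes a single pass over the pixels and never builds the boolean mask, so whatever time remains is mostly the quantile computation.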
It seems the bottleneck is the conversion of the image to a matrix.
I could do a faster conversion with np.reshape(img, (num_px_channel, 3)),
yet the operations that follow would then be slower.
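
To make the alternative layout concrete, here is a minimal sketch of the same statistic computed on the `(num_px_channel, 3)` matrix, with every reduction moved from `axis=1` to `axis=0`. The function name `cal_img_stat_px_major` and the synthetic image are hypothetical; the sketch is meant only for comparing results and timings against `cal_img_stat`, not as a proposed solution.

```python
import numpy as np

def cal_img_stat_px_major(img, low_quantile=0.0, high_quantile=1.0):
    # One row per pixel, one column per channel.
    px = np.reshape(img, (-1, img.shape[2]))
    # Drop pixels whose channels are all zero.
    non_zero_px = px[np.any(px, axis=1)]
    # Per-channel quantile bounds, each of shape (num_channels,).
    min_vals, max_vals = np.quantile(non_zero_px, [low_quantile, high_quantile], axis=0)
    # The bounds broadcast along the pixel axis without adding a new axis.
    mask_px = np.logical_and(non_zero_px >= min_vals, non_zero_px <= max_vals)
    return np.mean(non_zero_px, axis=0, where=mask_px)

# Synthetic stand-in image with a band of all-zero pixels (assumption).
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(64, 48, 3), dtype=np.uint8)
img[:5] = 0

px_major_means = cal_img_stat_px_major(img, 0.05, 0.95)
print(px_major_means)
```

Both layouts should give the same three means, so the only open question is which one runs faster on a real image.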
Is there a better approach I could use here to get a better run time?