Relatively new to TensorFlow here, and I am facing an issue where I have not yet managed to find a good answer through searching. So here goes:
I am trying to understand why applying a filter function to my TensorFlow dataset makes the time it takes to write the dataset to disk explode: it jumps from around 1 minute to nearly two hours.
My code runs on Python 3.8 with TensorFlow 2.11.1.
import tensorflow as tf
from skimage import filters

# Filtering function
def meijering_filter(x):
    filtered = filters.meijering(x)
    return filtered

# Import training data
training_dataset = tf.keras.utils.image_dataset_from_directory(
    "path_to_training_dataset",
    labels=None,
    batch_size=5,
    image_size=(480, 640),
    shuffle=True,
    seed=42,
    subset='training',
    validation_split=0.2,
    color_mode='grayscale'
)

normalization_layer = tf.keras.layers.Rescaling(1./255)
normalized_train_dataset = training_dataset.map(lambda x: normalization_layer(x))
feat_training_dataset = normalized_train_dataset.map(
    lambda x: tf.numpy_function(meijering_filter, [x], tf.float32))

# Reshaping data, since numpy_function() returns tensors with an unknown shape
data_reshape = tf.keras.Sequential([tf.keras.layers.Input(shape=(480, 640, 1))])
feat_training_dataset = feat_training_dataset.map(lambda x: data_reshape(x))

# Saving the TensorFlow dataset for later consumption
feat_training_dataset.save("save_path_on_disk")
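As a side note on the reshape step above: an alternative I have seen to routing the tensors through a Sequential model is to reattach the static shape with set_shape() inside a map. A minimal sketch, using an identity stand-in for the meijering filter and a tiny made-up 8x8 dataset so it runs standalone:

```python
import numpy as np
import tensorflow as tf

def identity_filter(x):
    # Stand-in for the real meijering filter; only fixes the dtype.
    return x.astype(np.float32)

# Tiny synthetic grayscale images: 4 images of 8x8x1, batched in twos.
ds = tf.data.Dataset.from_tensor_slices(
    np.zeros((4, 8, 8, 1), dtype=np.float32)).batch(2)

# numpy_function erases the static shape of its output...
ds = ds.map(lambda x: tf.numpy_function(identity_filter, [x], tf.float32))

def restore_shape(t):
    # ...so reattach it here; the batch dimension stays dynamic.
    t.set_shape([None, 8, 8, 1])
    return t

ds = ds.map(restore_shape)
print(ds.element_spec.shape)  # (None, 8, 8, 1)
```

This avoids pushing every batch through a Keras model just to recover shape metadata.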
I am aware that save() requires at least one full compute of the dataset, and that the meijering filter is somewhat compute-intensive. Still, based on timing a take(1) on my dataset, I expected that computation to take a few minutes: the take(1) I timed at 0.12 s, and my entire dataset is only around 1500 images.
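For what it's worth, my understanding is that tf.data pipelines are lazy, so a take(1) only forces a single batch through the map, which can badly underestimate the cost of a full pass. A minimal sketch with a call counter standing in for the filter (the 300-element range dataset is made up for illustration):

```python
import tensorflow as tf

calls = 0

def expensive(x):
    # Stand-in for the compute-heavy filter; only counts invocations.
    global calls
    calls += 1
    return x

ds = tf.data.Dataset.range(300).map(
    lambda x: tf.py_function(expensive, [x], tf.int64))

next(iter(ds.take(1)))
calls_after_take = calls
print("after take(1):", calls_after_take)  # a handful at most, not 300

for _ in ds:  # a full pass, as save() would do
    pass
print("after full pass:", calls)
```

So if each filtered batch really costs a few seconds, save() pays that for every batch, while take(1) pays it once.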
I also tried skipping the data_reshape() step, but that did not make a notable difference.
Can someone help me understand why the save() in the code above takes nearly two hours, and is there a way to remedy this?