import tensorflow as tf
# Create dummy data
image_data = [tf.constant([1]), tf.constant([2]), tf.constant([3])]
caption_data = [tf.constant([10]), tf.constant([20]), tf.constant([30])]
target_data = [tf.constant([100]), tf.constant([200]), tf.constant([300])]
# Create a dataset from the dummy data
dataset = tf.data.Dataset.from_tensor_slices(((image_data, caption_data), target_data))
# Define a filter function that checks a simple condition on the target tensor
def filter_funct(data):
((image_tensor, caption_tensor), target_tensor) = data
return target_tensor > 150
# Apply the filter function
filtered_dataset = dataset.filter(filter_funct)
# Print the filtered dataset
for ((image_tensor, caption_tensor), target_tensor) in filtered_dataset:
print("Image Tensor:", image_tensor.numpy())
print("Caption Tensor:", caption_tensor.numpy())
print("Target Tensor:", target_tensor.numpy())
I wanted to slice for some corrupted images but I can’t use tf.filter correctly. Above is just a dummy data. it is not working.
this is the error code
TypeError: outer_factory.<locals>.inner_factory.<locals>.tf__filter_funct() takes 1 positional argument but 2 were given