I'm getting an error during model training that may be caused by my code. I want to build a mushroom image classification model with two outputs, to be deployed in a mobile app via TFLite. The first output is edibility (Edible or Non-Edible) and the second is the name of the mushroom type (20 type names). In the code below, I use MobileNetV2 for transfer learning:
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input
from tensorflow.keras.models import Model
# Load the MobileNetV2 model with pre-trained weights
IMG_SHAPE = (224, 224, 3)
base_model = tf.keras.applications.MobileNetV2(weights='imagenet', include_top=False, input_shape=IMG_SHAPE)
# Freeze the base model
base_model.trainable = False
# Add custom layers on top of the base model
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(512, activation='relu')(x)
# Output layer for classifying mushroom type (20 classes)
mushroom_class = Dense(20, activation='softmax', name='mushroom_class')(x)
# Output layer for classifying edibility (2 classes)
edibility_output = Dense(2, activation='softmax', name='edibility_output')(x)
# Create the full model with two outputs
model = Model(inputs=base_model.input, outputs=[edibility_output, mushroom_class])
# Summary of the model
model.summary()
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import sparse_categorical_crossentropy
# Compile the model
model.compile(optimizer=Adam(),
              loss={'edibility_output': 'binary_crossentropy', 'mushroom_class': 'categorical_crossentropy'},
              metrics={'mushroom_class': 'accuracy', 'edibility_output': 'accuracy'})
I also created a custom generator that splits each batch of labels into edibility labels and mushroom type labels:
def custom_generator(generator):
    while True:
        x, y = generator.next()
        # Split labels into two parts: 1-class edible and 20-class mushroom types
        edible_labels = y[:, :2]  # Assuming the first column is edible labels
        mushroom_labels = y[:, 0:20]  # The remaining columns are mushroom type labels
        yield x, {'edibility_output': edible_labels, 'mushroom_class': mushroom_labels}
train_generator_custom = custom_generator(train_generator)
validation_generator_custom = custom_generator(validation_generator)
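For context, here is one possible way to derive both labels from file paths instead of slicing a single label array. It is only a sketch and rests on an assumption: that images are stored as `<Edible|Non-Edible>/<type name>/<image>` inside each split folder. The helper `labels_from_paths` and the example folder names (`type_a`, `type_b`, `type_c`) are hypothetical and not part of the original code.

```python
def labels_from_paths(rel_paths):
    """Map relative paths such as 'Edible/type_a/1.jpg' to two integer labels.

    Assumed (hypothetical) layout: <Edible|Non-Edible>/<type name>/<image>.
    Returns (edibility indices, type indices, sorted type names).
    """
    # Normalise Windows-style separators so the split below works everywhere.
    split_paths = [p.replace("\\", "/").split("/") for p in rel_paths]

    # Type indices follow the alphabetical order of the type folder names.
    type_names = sorted({parts[1] for parts in split_paths})
    type_index = {name: i for i, name in enumerate(type_names)}
    edible_index = {"Edible": 0, "Non-Edible": 1}

    edible = [edible_index[parts[0]] for parts in split_paths]
    types = [type_index[parts[1]] for parts in split_paths]
    return edible, types, type_names


# Hypothetical example paths, for illustration only.
paths = ["Edible/type_a/1.jpg", "Non-Edible/type_b/2.jpg", "Edible/type_c/3.jpg"]
edible, types, names = labels_from_paths(paths)
print(edible)  # [0, 1, 0]
print(types)   # [0, 1, 2]
print(names)   # ['type_a', 'type_b', 'type_c']
```

The relative paths could come from the `filenames` attribute of the directory iterator. Integer labels like these would pair with `sparse_categorical_crossentropy` rather than the one-hot losses used in `model.compile` above; one-hot encoding them is the alternative.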
The folder layout looks more or less like the image below:
I also set up ImageDataGenerator:
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Path to your training and validation directories
train_dir = 'mushroom3/MO_95/mushroom_dataset/train'
validation_dir = 'mushroom3/MO_95/mushroom_dataset/test'
train_ediblemushroom_dir = os.path.join(train_dir, 'Edible')
train_inediblemushroom_dir = os.path.join(train_dir, 'Non-Edible')
validation_ediblemushroom_dir = os.path.join(validation_dir, 'Edible')
validation_inediblemushroom_dir = os.path.join(validation_dir, 'Non-Edible')
# Define ImageDataGenerator for training data
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)
validation_datagen = ImageDataGenerator(rescale=1./255)
# Generate batches of augmented data from the directories
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=20,
    class_mode='categorical'  # Because you have multiple classes
)
validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(224, 224),
    batch_size=20,
    class_mode='categorical'  # Because you have multiple classes
)
This is the code that trains the model:
history = model.fit(
    train_generator_custom,
    steps_per_epoch=len(train_generator),
    epochs=200,
    validation_data=validation_generator_custom,
    validation_steps=len(validation_generator)
)
When I start training, the following error appears:
InvalidArgumentError: Graph execution error:
........
........
Node: 'categorical_crossentropy/softmax_cross_entropy_with_logits'
logits and labels must be broadcastable: logits_size=[20,20] labels_size=[20,2]
[[{{node categorical_crossentropy/softmax_cross_entropy_with_logits}}]] [Op:__inference_train_function_8124]
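The two shapes in the message can be reproduced without TensorFlow. The sketch below assumes that `flow_from_directory` finds only the two class folders directly under `train_dir` (`Edible` and `Non-Edible`, as the `os.path.join` calls suggest), so each one-hot label row has 2 entries. The list `y` is a stand-in for one label batch and the `shape` helper is hypothetical.

```python
# Assumption: the generator sees 2 class folders, so each label row has width 2.
batch_size = 20  # same batch size as in flow_from_directory

# Stand-in for the `y` batch: 20 one-hot rows of width 2.
y = [[1.0, 0.0] if i % 2 == 0 else [0.0, 1.0] for i in range(batch_size)]

# The same slices as in custom_generator, applied row by row.
edible_labels = [row[:2] for row in y]
mushroom_labels = [row[0:20] for row in y]  # slicing past the end does not pad


def shape(rows):
    """Return (rows, columns) for a rectangular list of lists."""
    return (len(rows), len(rows[0]))


print(shape(edible_labels))    # (20, 2)
print(shape(mushroom_labels))  # (20, 2)
```

Under that assumption, `mushroom_labels` has shape (20, 2) while the `mushroom_class` head outputs (20, 20), which matches `logits_size=[20,20] labels_size=[20,2]` in the error.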
Can anyone help me understand and fix this error?
After converting the model to TFLite, I want the mobile app to output both the edibility and the name of the mushroom type.