I’m currently working on a multi-label image classification problem where my dataset is significantly imbalanced. Each image can have multiple labels, provided in a one-hot (multi-hot) encoded format, e.g. [1,0,0,0,0,1,0].
My train df:
Image Index Finding Labels
0 00005504_002.png Pleural_Thickening
1 00003527_002.png Atelectasis|Pneumonia
2 00018285_000.png Effusion|Mass
3 00016971_007.png Emphysema|Mass
4 00014022_071.png Atelectasis|Consolidation|Pleural_Thickening
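For context, this is roughly how I build the multi-hot target vectors from the pipe-separated "Finding Labels" column (a pure-Python sketch using only the five sample rows above; my real code works on the full dataframe):

```python
# Hypothetical sketch: the five "Finding Labels" values from the preview above
rows = [
    "Pleural_Thickening",
    "Atelectasis|Pneumonia",
    "Effusion|Mass",
    "Emphysema|Mass",
    "Atelectasis|Consolidation|Pleural_Thickening",
]

# Sorted vocabulary of all distinct findings in the column
classes = sorted({label for row in rows for label in row.split("|")})

def to_multi_hot(row, classes=classes):
    # 1 where the finding is present in this row, 0 otherwise
    present = set(row.split("|"))
    return [1 if c in present else 0 for c in classes]

targets = [to_multi_hot(r) for r in rows]
```

(`sklearn.preprocessing.MultiLabelBinarizer` does the same thing; the manual version just makes the encoding explicit.)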
To balance the dataset, I considered undersampling the overrepresented classes. However, I'm concerned that discarding training samples would hurt the model's effectiveness.
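As an alternative to undersampling, I've been looking at oversampling rare labels via `torch.utils.data.WeightedRandomSampler`. Here is my sketch (an assumption on my part, using hypothetical inverse-frequency per-sample weights and the five sample rows above as multi-hot targets):

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical multi-hot targets for the five sample rows above
# (columns sorted alphabetically: Atelectasis, Consolidation, Effusion,
#  Emphysema, Mass, Pleural_Thickening, Pneumonia)
targets = torch.tensor([
    [0, 0, 0, 0, 0, 1, 0],  # Pleural_Thickening
    [1, 0, 0, 0, 0, 0, 1],  # Atelectasis|Pneumonia
    [0, 0, 1, 0, 1, 0, 0],  # Effusion|Mass
    [0, 0, 0, 1, 1, 0, 0],  # Emphysema|Mass
    [1, 1, 0, 0, 0, 1, 0],  # Atelectasis|Consolidation|Pleural_Thickening
], dtype=torch.float32)

class_counts = targets.sum(dim=0)                 # occurrences of each label
class_weights = 1.0 / class_counts.clamp(min=1)   # inverse frequency
# A sample's weight is the sum of the weights of its labels,
# so images containing rare findings are drawn more often.
sample_weights = (targets * class_weights).sum(dim=1)

sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)
# then: DataLoader(train_dataset, batch_size=..., sampler=sampler)
```

I'm not sure this is the right weighting scheme for multi-label data, since one sample can carry both a rare and a common label.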
I am looking to implement class weights/sample weights in PyTorch. How can I do this effectively for this multi-label classification problem? I have read online that class weights might not work well with one-hot encoded labels, and that sample weights or a custom loss function might be necessary. How can I implement a custom loss with weighted sampling? Any advice would be appreciated.
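This is my current attempt at a class-weighted loss, assuming `nn.BCEWithLogitsLoss`'s `pos_weight` argument (one weight per class, here computed as negatives/positives from a hypothetical multi-hot target matrix). Is this the right direction?

```python
import torch
import torch.nn as nn

# Hypothetical [N, num_classes] multi-hot target matrix
# (same five sample rows as in my dataframe preview)
targets = torch.tensor([
    [0, 0, 0, 0, 0, 1, 0],
    [1, 0, 0, 0, 0, 0, 1],
    [0, 0, 1, 0, 1, 0, 0],
    [0, 0, 0, 1, 1, 0, 0],
    [1, 1, 0, 0, 0, 1, 0],
], dtype=torch.float32)

num_samples = targets.shape[0]
pos_counts = targets.sum(dim=0)                    # positives per class
neg_counts = num_samples - pos_counts              # negatives per class
pos_weight = neg_counts / pos_counts.clamp(min=1)  # upweight rare positives

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(num_samples, targets.shape[1])  # raw model outputs
loss = criterion(logits, targets)
```

Note that `BCEWithLogitsLoss` expects raw logits, so the final sigmoid would move out of the model (or only be applied at inference time).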
Below is my current code (still in Keras):
```python
from tensorflow.keras.applications import ResNet101
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import AUC
from tensorflow.keras.callbacks import EarlyStopping

# ResNet-101 backbone (variable name kept from my code)
resnet50 = ResNet101(input_shape=(256, 256, 3), weights='imagenet', include_top=False)
for layer in resnet50.layers[:-3]:
    layer.trainable = False

x = Flatten()(resnet50.output)
x = Dense(512, activation='relu')(x)
prediction = Dense(13, activation='sigmoid')(x)
model = Model(inputs=resnet50.input, outputs=prediction)

learning_rate = 0.001
adam_optimizer = Adam(learning_rate=learning_rate)
model.compile(optimizer=adam_optimizer,
              loss='binary_crossentropy',
              metrics=['accuracy', AUC(multi_label=True)])

early_stopping = EarlyStopping(monitor='val_auc', patience=5, restore_best_weights=True)
history = model.fit(train_dataset, epochs=100,
                    validation_data=val_dataset,
                    callbacks=[early_stopping])
```