I want to train a model on the MedMNIST database, specifically RetinaMNIST.
I start by defining the data transforms and creating the dataloaders:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import os
from medmnist import RetinaMNIST
import numpy as np
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torchvision.utils as vutils
from tqdm import tqdm
from resnet_pytorch import ResNet
import argparse
import random
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
data_mean = [0.5]
data_std = [0.5]
# data transform
data_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(data_mean, data_std)
])
BATCH_SIZE = 100
trainset = RetinaMNIST(
split="train",
transform=data_transform,
size=224
)
valset = RetinaMNIST(
split="train",
transform=data_transform,
size=224
)
# change shape and type of labels to be the same as the CIFAR10 dataset (maybe this will fix something?)
newLabels = []
for label in trainset.labels:
newLabels.append(label[0])
trainset.labels = newLabels
newLabels = []
for label in valset.labels:
newLabels.append(label[0])
valset.labels = newLabels
# Create data loaders.
train_loader = DataLoader(
trainset,
batch_size=BATCH_SIZE,
shuffle=True
)
valid_loader = DataLoader(
valset,
batch_size=BATCH_SIZE,
shuffle=False
)
Note that I change the Labels from a numpy list of lists (where each label was a list of size one), to be a regular list. This is the same as the target array in the CIFAR10 dataset and I implemented this to try and fix the problem but to no avail.
i.e. the program works when I use this dataset in the dataloaders:
# CIFAR10 training dataset.
dataset_train = datasets.CIFAR10(
root='data',
train=True,
download=True,
transform=ToTensor(),
)
# CIFAR10 validation dataset.
dataset_valid = datasets.CIFAR10(
root='data',
train=False,
download=True,
transform=ToTensor(),
)
# Create data loaders.
train_loader = DataLoader(
dataset_train,
batch_size=BATCH_SIZE,
shuffle=True
)
valid_loader = DataLoader(
dataset_valid,
batch_size=BATCH_SIZE,
shuffle=False
)
I then define my training and validation loops:
# From https://debuggercafe.com/training-resnet18-from-scratch-using-pytorch/
def train(model, trainloader, optimizer, criterion, device):
model.train()
print('Training')
train_running_loss = 0.0
train_running_correct = 0
counter = 0
for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):
counter += 1
image, label = data
image = image.to(device)
label = label.to(device)
optimizer.zero_grad()
# Forward pass.
outputs = model(image)
# Calculate the loss.
print("label: ", type(label), label, label.shape)
print("outputs: ", type(outputs), outputs, outputs.shape)
loss = criterion(outputs, label)
train_running_loss += loss.item()
# Calculate the accuracy.
_, preds = torch.max(outputs.data, 1)
train_running_correct += (preds == label).sum().item()
# Backpropagation
loss.backward()
# Update the weights.
optimizer.step()
# Loss and accuracy for the complete epoch.
epoch_loss = train_running_loss / counter
# epoch_acc = 100. * (train_running_correct / len(trainloader.dataset))
epoch_acc = 100. * (train_running_correct / len(trainloader.dataset))
return epoch_loss, epoch_acc
def validate(model, testloader, criterion, device):
model.eval()
print('Validation')
valid_running_loss = 0.0
valid_running_correct = 0
counter = 0
with torch.no_grad():
for i, data in tqdm(enumerate(testloader), total=len(testloader)):
counter += 1
image, label = data
image = image.to(device)
label = label.to(device)
# Forward pass.
outputs = model(image)
# Calculate the loss.
loss = criterion(outputs, label)
valid_running_loss += loss.item()
# Calculate the accuracy.
_, preds = torch.max(outputs.data, 1)
valid_running_correct += (preds == label).sum().item()
# Loss and accuracy for the complete epoch.
epoch_loss = valid_running_loss / counter
epoch_acc = 100. * (valid_running_correct / len(testloader.dataset))
return epoch_loss, epoch_acc
Import my model:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
Define other parametres of the training process:
# Learning and training parameters.
epochs = 20
batch_size = 100
learning_rate = 0.01
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")
# Optimizer.
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
# Loss function.
criterion = nn.CrossEntropyLoss()`
Lastly the whole loop:
from IPython.display import clear_output
train_acc, valid_acc = [], []
train_loss, valid_loss = [], []
# Start the training.
for epoch in range(epochs):
clear_output(wait=True)
print(f"[INFO]: Epoch {epoch+1} of {epochs}")
train_epoch_loss, train_epoch_acc = train(
model,
train_loader,
optimizer,
criterion,
device
)
valid_epoch_loss, valid_epoch_acc = validate(
model,
valid_loader,
criterion,
device
)
train_loss.append(train_epoch_loss)
valid_loss.append(valid_epoch_loss)
train_acc.append(train_epoch_acc)
valid_acc.append(valid_epoch_acc)
print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
print('-'*50)
# Save the loss and accuracy plots.
save_plots(
train_acc,
valid_acc,
train_loss,
valid_loss,
name="training_output"
)
print('TRAINING COMPLETE')
When I run this I get the following error (note the print statements are for analysing what the labels and outputs are right before the error which occurs on the “loss = criterion(outputs, label)” line in the train function:
[INFO]: Epoch 1 of 20
Training
0%| | 0/11 [00:01<?, ?it/s]
label: <class 'torch.Tensor'> tensor([0, 0, 3, 3, 0, 4, 0, 3, 1, 4, 0, 0, 3, 3, 0, 3, 0, 2, 0, 2, 3, 3,
2, 3, 1, 1, 4, 2, 0, 0, 3, 2, 0, 2, 0, 0, 2, 0, 2, 0, 2, 0, 0, 0, 1, 0, 1, 4, 2, 0, 0, 3, 0, 2, 1, 0, 1,
3, 4, 3, 4, 0, 0, 3, 3, 3, 3, 3, 0, 2, 0, 0, 4, 3, 3, 0, 0, 2, 0, 1, 0, 2, 1, 1, 0, 0, 1, 0, 3, 4, 4, 1,
0, 4, 0, 0, 3, 3, 0, 4], dtype=torch.int32) torch.Size([100])
outputs: <class 'torch.Tensor'> tensor([[ 0.1170, -0.1176, -0.3459, ..., -0.2842, 0.0548, 0.5604], [
0.1989, 0.0700, -0.4919, ..., -0.5653, -0.0517, 0.5022], [ 0.2417, -0.0707, -0.6511, ..., -0.4755,
0.0190, 0.5151], ..., [ 0.2001, -0.0171, -0.6624, ..., -0.5510, 0.0850, 0.5551], [ 0.2618, -0.1657,
-0.3384, ..., -0.2892, 0.0925, 0.4892], [ 0.1843, -0.1776, -0.3948, ..., -0.2386, -0.0104,
0.5918]], grad_fn=<AddmmBackward0>) torch.Size([100, 1000])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[80], line 10
8 clear_output(wait=True)
9 print(f"[INFO]: Epoch {epoch+1} of {epochs}")
---> 10 train_epoch_loss, train_epoch_acc = train(
11 model,
12 train_loader,
13 optimizer,
14 criterion,
15 device
16 )
17 valid_epoch_loss, valid_epoch_acc = validate(
18 model,
19 valid_loader,
20 criterion,
21 device
22 )
23 train_loss.append(train_epoch_loss)
Cell In[76], line 22
19 print("label: ", type(label), label, label.shape)
20 print("outputs: ", type(outputs), outputs, outputs.shape)
---> 22 loss = criterion(outputs, label)
23 train_running_loss += loss.item()
...
3057 if size_average is not None or reduce is not None:
3058 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3059 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction),
ignore_index, label_smoothing)
RuntimeError: expected scalar type Long but found Int
Output is truncated. View as a scrollable element or open in a text editor. Adjust cell output
settings...
Do I need to change the data types of the dataset? There seems to be something wrong with the MedMNIST dataset resulting in me not being able to use it. Once again the CIFAR10 data set works fine.
Any help or suggestions appreciates 🙂