As stated in the title. I am getting the error shown there. The error occurs at the line where it says outputs = model(images)
. I am trying to use swin transformer for image classification on 5 classes and am fine tuning the last few layers only. Here is the code:
# Load the pre-trained Swin Transformer model
base_model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
# Modify the model to include global average pooling before the classification head
class SwinClassifier(nn.Module):
def __init__(self, base_model, num_classes):
super(SwinClassifier, self).__init__()
self.base_model = base_model
self.global_pool = nn.AdaptiveAvgPool2d(1) # Global average pooling to reduce spatial dimensions to 1x1
self.fc = nn.Linear(base_model.num_features, num_classes) # Fully connected layer
def forward(self, x):
x = self.base_model.forward_features(x) # Extract features
print(f"Shape after forward_features: {x.shape}") # Debugging line
x = self.global_pool(x) # Apply global average pooling (output shape: [batch_size, num_features, 1, 1])
print(f"Shape after global pooling: {x.shape}") # Debugging line
x = x.view(x.size(0), -1) # Flatten the tensor to shape [batch_size, num_features]
print(f"Shape after flattening: {x.shape}") # Debugging line
x = self.fc(x) # Classification head to get the final output
return x
# Create an instance of the modified model
model = SwinClassifier(base_model, num_classes=5)
# Unfreeze the last 4 layers (blocks 9 to 12)
for name, param in model.base_model.named_parameters():
if "layers.3.blocks.9" in name or "layers.3.blocks.10" in name or "layers.3.blocks.11" in name or "layers.3.blocks.12" in name:
param.requires_grad = True # Unfreeze the last 4 layers
else:
param.requires_grad = False # Freeze the rest
# Ensure that the new head parameters are trainable
for param in model.fc.parameters():
param.requires_grad = True
# Move model to device
model.to(device)
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
# Early stopping parameters
patience = 5 # Number of epochs with no improvement before stopping
best_val_loss = float('inf') # Initialize the best validation loss to infinity
epochs_without_improvement = 0 # Counter to track epochs without improvement
best_model_wts = None # To store the best model's weights
# Logging setup
logging.basicConfig(filename='train_log_MessiSwinFineTune3LyrAUCF1Kappa.txt', level=logging.INFO, format='%(asctime)s - %(message)s')
# a
epochs = 30 # Set the maximum number of epochs
for epoch in range(epochs):
model.train()
running_loss = 0.0
correct = 0
total = 0
all_preds = [] # List to store all predictions for confusion matrix
all_labels = [] # List to store all true labels for confusion matrix
# Training loop
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
# Zero the parameter gradients
optimizer.zero_grad()
# Forward pass through the modified Swin model
outputs = model(images)
# Compute the loss
loss = criterion(outputs, labels)
# Backward pass and optimization
loss.backward()
optimizer.step()
# Statistics
running_loss += loss.item()
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
all_preds.extend(predicted.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
train_acc = 100 * correct / total
avg_train_loss = running_loss / len(train_loader)
Shape after forward_features: torch.Size([32, 7, 7, 1024])
Shape after global pooling: torch.Size([32, 7, 1, 1])
Shape after flattening: torch.Size([32, 7]) printed and runtime error saying mat1 and mat2 shapes cannot be multiplied (32×7 and 1024×5)
Above is printed as the output along with the error. Can someone tell me what exactly is wrong.
Muhammad Edexcel is a new contributor to this site. Take care in asking for clarification, commenting, and answering.
Check out our Code of Conduct.