I have code to get output from my model, but it is only giving a binary segmentation mask, like this image:
model.load_state_dict(torch.load(os.path.join(output_dir, "best_metric_model.pth")))
model.eval()

post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# define inference method
def inference(input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=cfg.unetr.img_shape,
            sw_batch_size=1,
            predictor=model,
        )

    if VAL_AMP:
        with torch.cuda.amp.autocast():
            return _compute(input)
    else:
        return _compute(input)

with torch.no_grad():
    # select one image to evaluate and visualize the model output
    val_data = val_ds[1]
    val_input = val_data["image"].unsqueeze(0).to(device)
    val_output = inference(val_input)
    val_output = post_trans(val_output[0])

    plt.figure("image", (24, 6))
    for i in range(4):
        plt.subplot(1, 4, i + 1)
        plt.title(f"image channel {i}")
        plt.imshow(val_data["image"][i, :, :, val_input.shape[-1] // 2].detach().cpu(), cmap="gray")
    plt.show()

    # visualize the 3 channels label corresponding to this image
    plt.figure("label", (18, 6))
    for i in range(3):
        plt.subplot(1, 3, i + 1)
        plt.title(f"label channel {i}")
        plt.imshow(val_data["label"][i, :, :, val_input.shape[-1] // 2].detach().cpu())
    plt.show()

    # visualize the 3 channels model output corresponding to this image
    plt.figure("output", (18, 6))
    for i in range(3):
        plt.subplot(1, 3, i + 1)
        plt.title(f"output channel {i}")
        plt.imshow(val_output[i, :, :, val_input.shape[-1] // 2].detach().cpu())
    plt.show()
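(Note: I replaced the deprecated AsDiscrete(threshold_values=True, logit_thresh=0.5) arguments with the current threshold=0.5 form.) As I understand it, after the sigmoid + threshold post-transform, val_output holds three binary channels, so imshow of a single channel can only ever render two tones. A minimal sketch to confirm this, assuming the usual BraTS convention of three overlapping channels (the exact shape depends on the dataset):

# a minimal sketch, assuming val_output has shape (3, H, W, D) after post_trans
print(val_output.shape)     # e.g. torch.Size([3, 240, 240, 155])
print(val_output.unique())  # tensor([0., 1.]) -- only two values, hence the binary-looking plot
for j, name in enumerate(['Tumor Core', 'Whole Tumor', 'Enhancing Tumor']):
    voxels = int(val_output[j].sum().item())
    print(f"{name}: {voxels} foreground voxels")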
I want my output to look like this image instead, which was produced by the following code:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

image_channels = ['FLAIR', 'T1w', 'T1gd', 'T2w']
label_channels = ['Tumor Core', 'Whole Tumor', 'Enhancing Tumor']

colors = [
    LinearSegmentedColormap.from_list("yellow_red", [(1, 1, 0), (1, 0, 0)], N=256),  # Yellow to Red
    LinearSegmentedColormap.from_list("purple", [(0.5, 0, 0.5), (1, 0, 1)], N=256),  # Purple
    LinearSegmentedColormap.from_list("green", [(0, 1, 0), (0, 0.5, 0)], N=256),     # Green
]

val_data_example = val_ds[1]
print(f"image shape: {val_data_example['image'].shape}")

plt.figure("image", (24, 6))
for i in range(len(image_channels)):
    plt.subplot(1, 4, i + 1)
    plt.title(f"{image_channels[i]} image", weight='bold')
    plt.imshow(val_data_example["image"][i, :, :, val_data_example['image'].shape[-1] // 2], cmap="gray")
plt.show()

print(f"label shape: {val_data_example['label'].shape}")

plt.figure("image with labels", (24, 6))
for i in range(len(image_channels)):
    plt.subplot(1, 4, i + 1)
    plt.title(f"{image_channels[i]} with labels", weight='bold')
    plt.imshow(val_data_example["image"][i, :, :, val_data_example['image'].shape[-1] // 2], cmap="gray")
    for j in range(len(label_channels)):
        label_slice = val_data_example["label"][j, :, :, val_data_example['image'].shape[-1] // 2]
        plt.imshow(np.ma.masked_where(label_slice == 0, label_slice), cmap=colors[j], alpha=0.5)
plt.show()
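The overlay trick in that code is np.ma.masked_where: masked pixels are simply not drawn, so each imshow call only paints the foreground of its label channel over the grayscale image. A self-contained toy example of the same technique (synthetic data, no MONAI needed):

import numpy as np
import matplotlib.pyplot as plt

# toy grayscale "image" and two overlapping binary "label" masks
img = np.random.rand(64, 64)
mask_a = np.zeros((64, 64)); mask_a[10:40, 10:40] = 1
mask_b = np.zeros((64, 64)); mask_b[25:55, 25:55] = 1

plt.imshow(img, cmap="gray")
# masked_where hides the zeros, so each call only paints its own foreground
plt.imshow(np.ma.masked_where(mask_a == 0, mask_a), cmap="Reds", alpha=0.5)
plt.imshow(np.ma.masked_where(mask_b == 0, mask_b), cmap="Blues", alpha=0.5)
plt.show()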
So I tried to write similar code for my model's output:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from monai.transforms import Compose, Activations, AsDiscrete
from monai.inferers import sliding_window_inference

image_channels = ['FLAIR', 'T1w', 'T1gd', 'T2w']
label_channels = ['Tumor Core', 'Whole Tumor', 'Enhancing Tumor']
actual_colors = ['Reds', 'Blues', 'Greens']           # color maps for the actual labels
predicted_colors = ['Oranges', 'Purples', 'rainbow']  # contrasting color maps for the predicted labels

# Post-processing for multi-class segmentation
post_trans = Compose([
    Activations(softmax=True),  # apply softmax to get probabilities
    AsDiscrete(argmax=True)     # select the class with the highest probability
])

# Define inference method
def inference(input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=cfg.unetr.img_shape,
            sw_batch_size=1,
            predictor=model,
        )

    if VAL_AMP:
        with torch.cuda.amp.autocast():
            return _compute(input)
    else:
        return _compute(input)

# Load the model
model.load_state_dict(torch.load(os.path.join(output_dir, "best_metric_model.pth")))
model.eval()

# Select one image to evaluate and visualize the model output
val_data = val_ds[1]
val_input = val_data["image"].unsqueeze(0).to(device)

with torch.no_grad():
    val_output = inference(val_input)
    val_output = post_trans(val_output[0])

plt.figure("Visualization", (24, 18))

# Visualize image channels
for i in range(len(image_channels)):
    plt.subplot(3, 4, i + 1)
    plt.title(f"{image_channels[i]} image", weight='bold')
    plt.imshow(val_data["image"][i, :, :, val_input.shape[-1] // 2].detach().cpu(), cmap="gray")
    plt.axis('off')

# Visualize labels
for i in range(len(image_channels)):
    plt.subplot(3, 4, i + 5)
    plt.title(f"{image_channels[i]} with labels", weight='bold')
    plt.imshow(val_data["image"][i, :, :, val_input.shape[-1] // 2].detach().cpu(), cmap="gray")
    plt.axis('off')
    # Overlay each true label channel
    for j in range(len(label_channels)):
        label_slice = val_data["label"][j, :, :, val_input.shape[-1] // 2].detach().cpu()
        plt.imshow(np.ma.masked_where(label_slice == 0, label_slice), cmap=actual_colors[j], alpha=0.5)

# Visualize model outputs
for i in range(len(image_channels)):
    plt.subplot(3, 4, i + 9)
    plt.title(f"{image_channels[i]} with outputs", weight='bold')
    plt.imshow(val_data["image"][i, :, :, val_input.shape[-1] // 2].detach().cpu(), cmap="gray")
    plt.axis('off')
    # Overlay each model output channel
    for j in range(len(label_channels)):
        # extract the slice for the current class and squeeze the singleton channel dim
        output_slice = (val_output == j).float().cpu().squeeze()[..., val_input.shape[-1] // 2]
        plt.imshow(np.ma.masked_where(output_slice == 0, output_slice), cmap=predicted_colors[j], alpha=0.3)

plt.show()
The output I am getting from this is:
So I need help on how to get the output to look like the image I mentioned above.
I want my labels to be in different colors, overlaid on each other and on the underlying image.
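For reference, here is a minimal sketch of how I imagine the output row could be overlaid the same way as the labels. It assumes (and this is an assumption about my setup, not something confirmed anywhere above) that the three output channels are independent multi-label masks, as in the BraTS tutorials, so it uses sigmoid + threshold instead of softmax + argmax. It reuses inference, val_data, and val_input from the code above:

# a minimal sketch, assuming the 3 output channels are independent (multi-label)
# BraTS-style masks: sigmoid + threshold rather than softmax + argmax
from monai.transforms import Compose, Activations, AsDiscrete

post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

with torch.no_grad():
    val_output = post_trans(inference(val_input)[0])  # shape (3, H, W, D), values 0/1

slice_idx = val_input.shape[-1] // 2
plt.figure("overlaid output", (6, 6))
plt.imshow(val_data["image"][0, :, :, slice_idx].detach().cpu(), cmap="gray")
for j in range(3):
    out_slice = val_output[j, :, :, slice_idx].detach().cpu()
    # draw each predicted channel in its own colormap, hiding the background
    plt.imshow(np.ma.masked_where(out_slice == 0, out_slice),
               cmap=['Reds', 'Blues', 'Greens'][j], alpha=0.5)
plt.axis('off')
plt.show()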