Given plot like this:
Here is the code:
#@title Visualizer Function
def plot_delineation_comparison(Xt, yt, Xp, yp, start, stop=None, rec_name = '-', lead_name = '-', pathology_name = '-'):
if stop is None:
stop = -1
Xt = Xt[start:stop]
yt = yt[start:stop]
Xp = Xp[start:stop]
yp = yp[start:stop]
# Get mask of every class for prediction
bl_pred = yp == 0
p_pred = yp == 1
qrs_pred = yp == 2
t_pred = yp == 3
# Get mask of every class for ground truth
bl_true = yt == 0
p_true = yt == 1
qrs_true = yt == 2
t_true = yt == 3
# Create figure with two rows and one column
fig, (ax1, ax2) = plt.subplots(
2,
1,
figsize=(16, 8),
sharex=True,
gridspec_kw={"hspace": 0},
)
# Plotting for prediction
prev_class = None
start_idx = 0
for i in range(stop - start):
current_class = None
if bl_pred[i]:
current_class = 'grey'
elif p_pred[i]:
current_class = 'orange'
elif qrs_pred[i]:
current_class = 'green'
elif t_pred[i]:
current_class = 'purple'
if current_class != prev_class:
if prev_class is not None:
ax2.axvspan(start_idx, i, color=prev_class, alpha=0.5)
start_idx = i
prev_class = current_class
# Fill the last region
if prev_class is not None:
ax2.axvspan(start_idx, stop - start, color=prev_class, alpha=0.5)
# Plotting for ground truth
prev_class = None
start_idx = 0
for i in range(stop - start):
current_class = None
if bl_true[i]:
current_class = 'grey'
elif p_true[i]:
current_class = 'orange'
elif qrs_true[i]:
current_class = 'green'
elif t_true[i]:
current_class = 'purple'
if current_class != prev_class:
if prev_class is not None:
ax1.axvspan(start_idx, i, color=prev_class, alpha=0.5)
start_idx = i
prev_class = current_class
# Fill the last region
if prev_class is not None:
ax1.axvspan(start_idx, stop - start, color=prev_class, alpha=0.5)
# First row for ground truth (X_unseen, y_true)
ax1.plot(Xt, color='blue')
ax1.set_ylabel('Ground Truth')
# draw baseline at y=0
ax1.axhline(y=0, color='red', linestyle='-', lw=0.5)
# Second row for ground truth (X_pred y_pred)
ax2.plot(Xp, color='blue')
ax2.axhline(y=0, color='red', linestyle='-', lw=0.5)
ax2.set_xlim([0, stop - start])
ax2.set_ylabel('Prediction')
ax2.set_xlabel('Index')
# Retrieve the current x-tick locations
current_xticks = ax2.get_xticks()
# Define the new x-tick labels based on absolute start and end
new_xtick_labels = [int(x + start) for x in current_xticks]
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
ax2.set_xticklabels(new_xtick_labels)
cm = ConfusionMatrix(actual_vector=yt.flatten(), predict_vector=yp.flatten(), transpose=True)
# Handle if not number type
cm.PPV = [0 if not type(x) == float else x for _,x in cm.PPV.items()]
cm.TPR = [0 if not type(x) == float else x for _,x in cm.TPR.items()]
# Make length of PPV and TPR consistent, fill with zero if not
if len(cm.PPV) < 4:
cm.PPV += [0] * (4 - len(cm.PPV))
if len(cm.TPR) < 4:
cm.TPR += [0] * (4 - len(cm.TPR))
if len(cm.F1) < 4:
# convert F1 to list and fixed it with length 4
cm.F1 = list(cm.F1.values())
cm.F1 += [0] * (4 - len(cm.F1))
notes_list = [
f"Recall",
f"BL : {cm.TPR[0]:.2f}",
f"P : {cm.TPR[1]:.2f}",
f"QRS : {cm.TPR[2]:.2f}",
f"T : {cm.TPR[3]:.2f}",
f"",
f"Rec Name : {rec_name}",
f"Lead : {lead_name}",
f"Pathology : {pathology_name}",
f"Unit : mV",
f"Sample Rate: 360Hz",
f"SNR(Pr./GT): {calculate_snr(Xp, Xt)}dB",
f"",
f"Precission",
f"BL : {cm.PPV[0]:.2f}",
f"P : {cm.PPV[1]:.2f}",
f"QRS : {cm.PPV[2]:.2f}",
f"T : {cm.PPV[3]:.2f}",
]
# notes_list += catatan
ax1.set_title(f"F1-Score | BL: {cm.F1[0]:.2f} | P: {cm.F1[1]:.2f} | QRS: {cm.F1[2]:.2f} | T: {cm.F1[3]:.2f}")
code_font = FontProperties(family='monospace', style='normal', variant='normal', size=8)
for i, note in enumerate(notes_list[:]):
plt.text(1.01, 0.95 - i * 0.1, note, transform=ax1.transAxes, fontsize=10, va='top', ha='left', fontproperties=code_font)
# for i, note in enumerate(notes_list[5:]):
# plt.text(1.01, 0.95 - i * 0.1, note, transform=ax2.transAxes, fontsize=10, va='top', ha='left', fontproperties=code_font)
plt.subplots_adjust(top=0.5)
# add legend with offset
# Create custom Line2D objects with desired colors
custom_lines = [
Line2D([0], [0], color='grey', lw=4, alpha=0.5),
Line2D([0], [0], color='orange', lw=4, alpha=0.5),
Line2D([0], [0], color='green', lw=4, alpha=0.5),
Line2D([0], [0], color='purple', lw=4, alpha=0.5)
]
# Add legend with custom lines
ax1.legend(custom_lines, ['BL', 'P', 'QRS', 'T'], loc='upper left')
plt.show()
I’m working on visualizing some data using Matplotlib, and I’ve encountered an issue with how the y-axis is displayed. Specifically, I have a plot where the y-axis shows non-integer values, and I would like to hide these non-integer values so that only integers are displayed.
Additionally, I need to fix the y-axis range between -2 and 2 for both the first and second rows of subplots in the figure. The goal is to maintain consistency in the y-axis range across these subplots.