h = 16
fig, ax = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))
for i, q_id in enumerate(sorted_indices[0]):
logit = itm_logit[:, q_id, :]
prob = torch.nn.functional.softmax(logit, dim=1)
name = f'{prob[0, 1]:.3f}_query_id_{q_id}'
# Attention map
attention_map = avg_cross_att[0, q_id, :-1].view(h, h).detach().cpu().numpy()
# Image
raw_image_resized = raw_image.resize((596, 596))
ax[0].set_title(name)
ax[0].imshow(attention_map, cmap='viridis')
ax[0].axis('off')
ax[1].set_title(caption)
ax[1].imshow(raw_image_resized)
ax[1].axis('off')
ax[2].set_title(f'Overlay: {name}')
ax[2].imshow(raw_image_resized)
ax[2].imshow(attention_map, cmap='viridis', alpha=0.6)
ax[2].axis('off')
ax[0].set_aspect('equal')
ax[1].set_aspect('equal')
ax[2].set_aspect('equal')
plt.tight_layout()
plt.savefig(f"./att_maps/{name}.jpg")
plt.show()
break
What I am trying to do is overlay the attention weights on top of the image (on thrid axes), so I can see which part of the image attention weight is more focused on.
However, the code that I put only overlap the attention weight on top of the image.
What might be the problem in this case?