I have a dataframe that I’m plotting using sns.histplot to make ~9 subplots. I do not want the figure to have 9 legends. I’d like 1 legend for the entire figure, since most of the legends are identical anyway. However, since not all the subplots are 100% identical (some have an extra “crop” or are missing one), I can’t simply remove all the legends except for 1.
I was able to do this for sns.barplot, but the same lines of code don’t seem to work with sns.histplot (and these data need to be plotted with sns.histplot).
I have some dummy code below to tinker with.
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
# Create a dummy DataFrame with random data
crops = [
    "apple",
    "banana",
    "orange",
    "potato",
    "zucchini",
    "kale",
    "strawberry",
    "raspberry",
    "turnip",
    "onion",
]
farms = [
    "n",
    "s",
    "e",
    "w",
]
np.random.seed(42)
n_samples = 100
counts = np.random.randint(1, 200, size=n_samples)
choices = np.random.choice(crops, size=n_samples)
locations = np.random.choice(farms, size=n_samples)
df = pd.DataFrame({"counts": counts, "crops": choices, "farms": locations})
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)
h = [] # Initialize an empty list to collect handles
l = [] # Initialize an empty list to collect labels
rowindex = [0, 0, 1, 1]
colindex = [0, 1, 0, 1]
selection = ["farms == 'n'", "farms == 's'", "farms == 'e'", "farms == 'w'"]
for row, col, sele in zip(rowindex, colindex, selection):
    g = sns.histplot(
        data=df.query(sele),
        x="counts",
        kde=True,
        stat="count",
        hue="crops",  # Add the hue parameter
        ax=axes[row, col],
    )
    handles, labels = g.get_legend_handles_labels()
    h.extend(handles)
    l.extend(labels)
    axes[row, col].get_legend().remove()  # Remove individual legends from subplots
# Create a single legend for the entire figure
by_label = dict(zip(l, h))
g.legend(by_label.values(), by_label.keys(), bbox_to_anchor=(0.9,0.65), loc='upper right')
plt.show()
With the following lines, I’m trying to collect all the legend entries from each subplot and throw them into handles[] and labels[], then turn them into a dictionary (to remove duplicate entries). I’ve tried g.get_legend_handles_labels() and axes[row,col].get_legend_handles_labels(), as well as fig. and ax., but each one turns up empty when I print the contents of handles and labels. I’m really confused because this did work with sns.barplot (but maybe they define artists differently? or store them differently?).
handles, labels = g.get_legend_handles_labels()
h.extend(handles)
l.extend(labels)
axes[row, col].get_legend().remove() # Remove individual legends from subplots
# Create a single legend for the entire figure
by_label = dict(zip(l, h))
g.legend(by_label.values(), by_label.keys(), bbox_to_anchor=(0.9,0.65), loc='upper right')
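To isolate the de-duplication step from the plotting, here is a minimal sketch of what dict(zip(l, h)) does. The strings are hypothetical placeholders standing in for the real matplotlib handle objects; they are not taken from the plot above.

```python
# Placeholder strings stand in for matplotlib handle objects (hypothetical values).
labels = ["apple", "kale", "apple", "onion", "kale"]
handles = ["h_apple_1", "h_kale_1", "h_apple_2", "h_onion_1", "h_kale_2"]

# A dict keeps one entry per label. A repeated label overwrites the stored
# handle but keeps the position of its first occurrence.
by_label = dict(zip(labels, handles))
print(list(by_label.keys()))
print(list(by_label.values()))

# If the collected lists are empty, the dict is empty too, so a legend built
# from it has nothing to show.
empty = dict(zip([], []))
print(empty)
```

So the dictionary step itself should be fine whenever the lists are populated; the missing legend is consistent with handles and labels coming back empty in the first place.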
Commenting those lines out, this is the plot:
With those lines in the code, I get no legends:
Ideally, I’d have 1 legend (outside the plot) with all the “crops” listed once (the order isn’t essential). Does anyone know what I’m doing wrong?
As mentioned, this WILL WORK with sns.barplot. If you change the g = bit to this:
g = sns.barplot(
    data=df.query(sele),
    x="counts",
    y="counts",
    hue="crops",  # Add the hue parameter
    ax=axes[row, col],
)
And keep the rest the same, then I get a figure like the one below, which is essentially what I’m going for: 1 legend with all the “crops” listed only once with their respective label marker. So it has to be possible with sns.histplot... right?