Suppose I have 9 plots I want to make from a dataframe with 9 columns. So I create 9 axes with `fig, axes = plt.subplots(3, 3)`. Now I iterate over my 9 columns with `for name, data_col in data.items()` and do my plotting (for example `data_col.value_counts().plot.barh()`).
The problem is that, because the array of axes is two-dimensional, I have to write extra code just to get the next axis to plot on. Instead of
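```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
data = pd.DataFrame(np.random.randint(1, 4, size=(50, 9)))
fig, axes = plt.subplots(3, 3)
for name, data_col in data.items():
    # plt.get_next_axis() is what I wish existed; matplotlib has no such function
    data_col.value_counts().plot.barh(ax=plt.get_next_axis())
```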
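I have to do something like this:
```python
nrows, ncols = axes.shape
for i, (name, data_col) in enumerate(data.items()):
    # convert the running index i into a (row, column) position in the grid
    ax = axes[i // ncols, i % ncols]
    data_col.value_counts().plot.barh(ax=ax)
```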
which is more cumbersome and error-prone.
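For what it's worth, one way to avoid the index arithmetic is to iterate over `axes.flat`, NumPy's flat iterator over the 2-D array of `Axes`, zipped with the columns. This is a minimal sketch, assuming the same random 9-column `data` and a 3×3 grid; the per-plot titles are just an illustrative extra:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

data = pd.DataFrame(np.random.randint(1, 4, size=(50, 9)))
fig, axes = plt.subplots(3, 3)

# axes.flat walks the 2-D array of Axes in row-major order, so zip pairs
# each column with the "next" axis without any (row, column) arithmetic.
for ax, (name, data_col) in zip(axes.flat, data.items()):
    data_col.value_counts().plot.barh(ax=ax, title=str(name))

# zip stops at the shorter sequence; if the grid has more cells than there
# are columns, hide the unused axes (a no-op here, since 9 == 9).
for ax in axes.flat[len(data.columns):]:
    ax.set_visible(False)
```

`axes.ravel()` or `axes.flatten()` would work the same way in the loop; `flat` just avoids building an intermediate array.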
If I were plotting something simple, like a line plot or a histogram, I could just call `data.plot.hist()` with arguments to make it all look as nice as I want. But if I need a slightly more complex plot, I have to iterate over the columns manually.
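For the simple case, pandas can create one subplot per column itself through the `subplots` and `layout` arguments of `DataFrame.plot`. A sketch, assuming the same 9-column `data`; `bins=3` is an arbitrary choice for values in the range 1 to 3:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

data = pd.DataFrame(np.random.randint(1, 4, size=(50, 9)))

# subplots=True draws each column on its own axis; layout arranges those
# axes in a 3x3 grid, and the returned array is shaped like that grid.
axes = data.plot.hist(subplots=True, layout=(3, 3), bins=3, legend=False)
```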
Is there a simple, elegant solution to this? Or am I just trying to solve something that doesn’t even need solving?