I want a specific subplot layout and found that GridSpec makes this possible. The upper chart is a bar plot of the counts per date, the right chart is a bar plot of the counts per variable (over all dates), and the main plot is a seaborn heatmap of a pandas DataFrame. The index of the DataFrame holds the variable names and the columns are datetimes. I would like axes (x and y) only on the main plot: x should show the years, and y should show the variables. However, the ticks end up messed up. I provide an example below.
To get the data (note, the NaNs are intended):
from random import randrange
import pandas as pd
import numpy as np
xnobs = 120
ynobs = 10
tmp = pd.DataFrame(np.nan,index=['V'+str(i) for i in range(ynobs)], columns=pd.date_range(start='1/1/2018', periods=xnobs))
for _ in range(xnobs*ynobs*2):
    tmp.iloc[randrange(0, ynobs), randrange(0, xnobs)] = np.abs(np.random.randn())
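Since the snippet above is unseeded, every run gives a different frame. In case it helps with reproducing the problem, here is a sketch of a seeded variant. The function name make_sparse_frame, the variable tmp_seeded, and the seed value 0 are my own choices for illustration and not part of my actual code:

```python
import numpy as np
import pandas as pd


def make_sparse_frame(xnobs=120, ynobs=10, seed=0):
    """Build the same kind of sparse frame as above, but reproducibly.

    The name and the default seed are illustrative assumptions.
    """
    rng = np.random.default_rng(seed)
    frame = pd.DataFrame(
        np.nan,
        index=['V' + str(i) for i in range(ynobs)],
        columns=pd.date_range(start='1/1/2018', periods=xnobs),
    )
    # Fill random cells with non-negative values; untouched cells stay NaN.
    for _ in range(xnobs * ynobs * 2):
        row = int(rng.integers(0, ynobs))
        col = int(rng.integers(0, xnobs))
        frame.iloc[row, col] = abs(rng.standard_normal())
    return frame


tmp_seeded = make_sparse_frame()
```

Using tmp_seeded in place of tmp in the plotting code should give the same picture on every run.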
To plot:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import ticker
title=''
ypos=0.885
# plot sparsity, but for some reason it does not show the date ticks
fig = plt.figure(figsize=(10, 10))
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
fig.suptitle(title, y= ypos, fontsize=12)
grid = plt.GridSpec(12, 12, hspace=0.2, wspace=0.5)
main_ax = fig.add_subplot(grid[1:-1, :-2])
y_hist = fig.add_subplot(grid[1:-1, -2:], yticklabels=[], sharey=main_ax)
x_hist = fig.add_subplot(grid[0, :-2], xticklabels=[], sharex=main_ax)
# plot the heatmap
g = sns.heatmap(tmp, cmap=matplotlib.cm.Reds, cbar=False, ax=main_ax)
# plot at the top: how many variables were selected on any given date
tmp.count(axis=0).plot.bar(ax=x_hist, width=1, align='edge', color='darkred', sharex=True, use_index=False, xticks=[])
# plot on the right: how many times a variable was selected over all dates
tmp.count(axis=1).plot.barh(ax=y_hist, width=1, align='edge', color='darkred', sharey=True)
y_hist.invert_yaxis()
main_ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
main_ax.set_yticklabels(['']+tmp.index.tolist())
# Set x-ticks to start of each year or significant date
year_starts = [i for i in range(len(tmp.columns)) if tmp.columns[i].month == 1]
main_ax.set_xticks(year_starts) # Set the ticks to the start of each year
main_ax.set_xticklabels([tmp.columns[y].year for y in year_starts]) # Set labels as year
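One thing I noticed while writing this up: the year_starts list above matches every column that falls in January, not only the first column of a year. What I actually intend is probably closer to the sketch below, which picks only the first column of each calendar year. The helper name year_start_positions and the short date range are made up for illustration, and I am not sure this alone explains the tick problem:

```python
import pandas as pd


def year_start_positions(columns):
    """Return integer positions of the first column in each calendar year.

    Hypothetical helper, not part of the plotting code above.
    """
    years = pd.Series(columns.year)
    # A column starts a new year if its year differs from the previous column's.
    is_first = years.ne(years.shift())
    return [int(i) for i in is_first[is_first].index]


# Small illustrative range that crosses a year boundary.
cols = pd.date_range(start='12/30/2017', periods=5)
positions = year_start_positions(cols)
labels = [int(cols[i].year) for i in positions]
```

For the 120 daily columns in my example data, all of which fall in 2018, this would give a single tick at position 0.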