def plot_batches(fulldata, cols=None, out_file=None, site_labels='left'):
    """Plot a heatmap of min-max normalized features, one row per sample, grouped by site.

    Rows of ``fulldata`` are sorted by ``database`` then ``site``. Each plotted
    column is rescaled to [0, 1] and drawn with ``imshow``. Sites are labelled
    on the y axis and separated by white horizontal lines.

    Parameters
    ----------
    fulldata : pandas.DataFrame
        Table with at least ``database`` and ``site`` columns. It is not
        modified; a sorted copy is used.
    cols : list of str, optional
        Columns to plot. When ``None``, all numeric columns are used.
    out_file : str or path-like, optional
        If given, the figure is also saved there (300 dpi, tight bounding box).
    site_labels : str
        ``'right'`` draws the site tick labels on the right-hand side;
        any other value leaves them on the left.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    fulldata = fulldata.sort_values(by=['database', 'site']).copy()
    sites = fulldata.site.values.ravel().tolist()

    # The original also did an unconditional ``numdata[cols]`` after this
    # branch. With ``cols=None`` that raised KeyError.
    if cols is None:
        numdata = fulldata.select_dtypes([np.number])
    else:
        numdata = fulldata[cols]

    # Min-max normalize each column to [0, 1]. Constant columns have a zero
    # range; divide those by 1 so they render as 0 instead of NaN.
    numdata = numdata - numdata.min()
    colmax = numdata.max()
    numdata = numdata / colmax.where(colmax != 0, 1)

    fig, ax = plt.subplots(figsize=(20, 10))
    ax.imshow(numdata.values, cmap=plt.cm.viridis, interpolation='nearest',
              aspect='auto')

    # Unique sites in row order (not set order). This keeps each tick
    # location paired with its label and guarantees the first site starts
    # at row 0, so skipping it below never drops a real boundary.
    unique_sites = list(dict.fromkeys(sites))
    site_values = fulldata.site.values
    locations = []
    spines = []
    for site in unique_sites:
        indices = np.flatnonzero(site_values == site)
        locations.append(int(np.average(indices)))  # label at the group's middle row
        spines.append(indices[0])                    # first row of the group

    if site_labels == 'right':
        ax.yaxis.tick_right()
        ax.yaxis.set_label_position("right")

    ax.set_xticks(range(numdata.shape[1]))
    ax.set_xticklabels(list(numdata.columns), rotation='vertical')
    ax.set_yticks(locations)
    ax.set_yticklabels(unique_sites)

    # imshow places row centres on integer y coordinates, so the boundary
    # above row ``k`` is at ``k - 0.5``. The first group starts at the top
    # edge and needs no separator.
    for first_row in spines[1:]:
        ax.axhline(y=first_row - 0.5, color='w', linestyle='-')

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.grid(False)

    ytick_font = FontProperties(
        family='FreeSans', style='normal', size=14,
        weight='normal', stretch='normal')
    for label in ax.get_yticklabels():
        label.set_fontproperties(ytick_font)

    xtick_font = FontProperties(
        family='FreeSans', style='normal', size=12,
        weight='normal', stretch='normal')
    for label in ax.get_xticklabels():
        label.set_fontproperties(xtick_font)

    if out_file is not None:
        fig.savefig(out_file, bbox_inches='tight', pad_inches=0, dpi=300)
    return fig
# Stray web-page navigation text (not code): "评论列表" (comment list), "文章目录" (table of contents)