def att_plot(top_labels, gt_ind, probs, fn):
    """Save a horizontal bar chart of label probabilities to ``fn``.

    Each bar is annotated just past its end with its label text. The
    ground-truth label gets a ``' (gt)'`` suffix. The x-axis is fixed to
    [0, 1]. The y-axis, the x-axis label, and the top/right spines are
    hidden. The figure is saved with a tight bounding box on a transparent
    background.

    Args:
        top_labels: Sequence of labels, one per bar (list, tuple, or array;
            each element is converted with ``str``).
        gt_ind: Index into ``top_labels`` of the ground-truth label, or
            ``None`` to mark no bar.
        probs: Sequence of bar lengths, same length as ``top_labels``.
        fn: Destination passed straight to ``Figure.savefig``.
    """
    # Build plain Python strings. A copied numpy '<U*' array would silently
    # truncate the ' (gt)' suffix to the original element width, and a
    # tuple would not support item assignment.
    lab = [str(label) for label in top_labels]
    if gt_ind is not None:
        lab[gt_ind] += ' (gt)'
    d = pd.DataFrame(data={'probs': probs, 'labels': lab})

    fig, ax = plt.subplots(figsize=(4, 5))
    try:
        ax.tick_params(labelsize=15)
        # NOTE(review): `ci` is deprecated in seaborn >= 0.12 in favour of
        # `errorbar=None`; kept here for compatibility with older seaborn.
        sns.barplot(y='labels', x='probs', ax=ax, data=d, orient='h', ci=None)
        ax.set(xlim=(0, 1))
        # Write each label just to the right of its bar instead of on the
        # (hidden) y-axis.
        for rect, label in zip(ax.patches, lab):
            w = rect.get_width()
            ax.text(w + .02, rect.get_y() + rect.get_height() * 4 / 5, label,
                    ha='left', va='bottom', fontsize=25)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().label.set_visible(False)
        fig.savefig(fn, bbox_inches='tight', transparent=True)
    finally:
        # Close only the figure created here (not every open figure), and
        # close it even when plotting or saving raises.
        plt.close(fig)
评论列表
文章目录