def grouped_bar(features, bar_labels=None, group_labels=None, ax=None, colors=None):
    '''
    Draw a grouped (clustered) bar chart.

    Each row of ``features`` is one bar series; each column is one group
    (one cluster of bars centred on an integer x position).

    features     : 2-D array-like shaped (n_bars, n_groups).
    bar_labels   : optional sequence of n_bars legend labels. A legend is
                   drawn only when this is given.
    group_labels : optional sequence of n_groups x-tick labels. Ticks are
                   placed at the group centres only when this is given.
    ax           : Axes to draw on. When None, a new 9x6 inch figure is
                   created.
    colors       : optional sequence indexed by bar number. When None, the
                   colours are sampled evenly from the spectral colormap.

    Returns None; the chart is drawn on ``ax``.

    >>> bars = np.random.rand(5,3)
    >>> grouped_bar(bars)
    >>> group_labels = ['group%d' % i for i in range(bars.shape[1])]
    >>> bar_labels = ['bar%d' % i for i in range(bars.shape[0])]
    >>> grouped_bar(bars, group_labels=group_labels, bar_labels=bar_labels)
    '''
    # Accept nested lists as well as ndarrays.
    features = np.asarray(features)
    n_bars, n_groups = features.shape[0:2]

    if ax is None:
        fig, ax = plt.subplots()
        fig.set_size_inches(9, 6)

    if colors is None:
        # ``spectral`` was renamed ``nipy_spectral`` and later removed from
        # matplotlib; prefer the new name and fall back for old releases.
        cmap = getattr(mpl.cm, 'nipy_spectral', None)
        if cmap is None:
            cmap = mpl.cm.spectral
        colors = cmap(np.linspace(0, 1, n_bars))

    index = np.arange(n_groups)
    # All bars of a group together occupy 75% of the unit group spacing.
    bar_width = 1.0 / n_bars * 0.75
    # Left edge of the first bar in each group, so the cluster of n_bars
    # bars is centred on the integer group position.
    first_left = index - bar_width * n_bars / 2.0

    for j, series in enumerate(features):
        label = bar_labels[j] if bar_labels is not None else None
        # align='edge' is required: the offsets above are left edges, while
        # current matplotlib defaults to centre alignment, which would shift
        # every cluster half a bar to the left of its tick.
        ax.bar(first_left + j * bar_width, series, bar_width,
               color=colors[j], label=label, alpha=0.4, align='edge')

    ax.margins(0.05, 0.0)  # so the bar graph is nicely padded

    if bar_labels is not None:
        # Anchor the legend just outside the top-right corner of the axes.
        ax.legend(loc='upper left', bbox_to_anchor=(1.0, 1.02), fontsize=14)

    if group_labels is not None:
        # Group centres coincide with the integer positions.
        ax.set_xticks(index)
        ax.set_xticklabels(group_labels, rotation=0.0)

    for item in (ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(14)
评论列表
文章目录