plot_utils.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:vsi_common 作者: VisionSystemsInc 项目源码 文件源码
def grouped_bar(features, bar_labels=None, group_labels=None, ax=None, colors=None):
  """Draw a grouped bar chart with one cluster of bars per column of *features*.

  Row ``j`` of *features* is drawn as the ``j``-th bar of every group, so
  ``features[j, g]`` is the height of bar ``j`` inside group ``g``.

  Parameters
  ----------
  features : array_like, shape (n_bars, n_groups)
    Bar heights; converted with ``np.asarray``.
  bar_labels : sequence of str, optional
    One legend label per row of *features*. When given, a legend is placed
    just outside the right edge of the axes.
  group_labels : sequence of str, optional
    One x tick label per column of *features*, placed at the centre of each
    group. When given, x and y tick labels are set to font size 14.
  ax : matplotlib Axes, optional
    Axes to draw into. When omitted, a new figure sized 9x6 inches is created.
  colors : sequence of colours, optional
    One colour per row of *features*. Defaults to evenly spaced samples of
    the ``nipy_spectral`` colormap.

  Raises
  ------
  ValueError
    If *features* is not two-dimensional.

  Examples
  --------
  >>> bars = np.random.rand(5,3)
  >>> grouped_bar(bars)

  >>> group_labels = ['group%d' % i for i in range(bars.shape[1])]
  >>> bar_labels = ['bar%d' % i for i in range(bars.shape[0])]
  >>> grouped_bar(bars, group_labels=group_labels, bar_labels=bar_labels)
  """
  features = np.asarray(features)
  if features.ndim != 2:
    raise ValueError('features must be 2-D (n_bars, n_groups), got shape %r'
                     % (features.shape,))
  n_bars, n_groups = features.shape

  if ax is None:
    fig, ax = plt.subplots()
    fig.set_size_inches(9, 6)

  if colors is None:
    # ``cm.spectral`` was removed in matplotlib 2.2; ``nipy_spectral`` is its
    # replacement.
    colors = mpl.cm.nipy_spectral(np.linspace(0, 1, n_bars))

  index = np.arange(n_groups)
  # All bars of one group together span 75% of the unit spacing between
  # group centres, leaving a gap between neighbouring groups.
  bar_width = 0.75 / n_bars
  # Left edge of the first bar of each group, chosen so that the group is
  # centred on its integer index.
  group_start = index - bar_width * n_bars / 2.0

  for j, heights in enumerate(features):
    label = bar_labels[j] if bar_labels is not None else None
    # align='edge' is required for the offsets above. matplotlib >= 2.0
    # defaults to align='center', which would shift every group half a bar
    # width to the left of its tick.
    ax.bar(group_start + j * bar_width, heights, bar_width, align='edge',
           color=colors[j], label=label, alpha=0.4)
  # 5% horizontal padding (none vertically) so the outer bars are not flush
  # against the axes edges.
  ax.margins(0.05, 0.0)

  if bar_labels is not None:
    # Anchor the legend's upper-left corner just outside the right edge.
    ax.legend(loc='upper left', bbox_to_anchor=(1.0, 1.02), fontsize=14)

  if group_labels is not None:
    # With edge-aligned bars, each group is centred on its integer index.
    ax.set_xticks(index)
    ax.set_xticklabels(group_labels, rotation=0.0)
    for item in ax.get_xticklabels() + ax.get_yticklabels():
      item.set_fontsize(14)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号