plot_utils.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:rec-attend-public 作者: renmengye 项目源码 文件源码
def plot_input(fname, x, y_gt, s_gt, max_items_per_row=9):
  """Plot each input image with its groundtruth boxes and save to a file.

  For every example ``ii`` and item ``jj`` a separate subplot shows the image
  ``x[ii]`` (channels reordered via ``[2, 1, 0]``) overlaid with the bounding
  box of the nonzero pixels of ``y_gt[ii, jj]``, plus a small filled tag
  labelled with the item index ``jj``. Items whose mask is all zeros show the
  image only.

  Args:
    fname: Output path passed to ``plt.savefig``.
    x: Input images; ``x[ii]`` is indexed as ``[:, :, channel]``. Presumably
      (num_ex, H, W, 3) in BGR order -- TODO confirm against callers.
    y_gt: Groundtruth masks with ``y_gt.shape[:2] == (num_ex, num_items)``;
      the first two axes of each ``y_gt[ii, jj]`` are read as (row, col).
    s_gt: Not used by this function; kept for interface compatibility.
    max_items_per_row: Forwarded to ``calc_row_col`` to lay out the grid.
  """
  num_ex = y_gt.shape[0]
  num_items = y_gt.shape[1]
  num_row, num_col, calc = calc_row_col(
      num_ex, num_items, max_items_per_row=max_items_per_row)

  _, axarr = plt.subplots(num_row, num_col, figsize=(20, num_row))
  set_axis_off(axarr, num_row, num_col)
  # Matplotlib color codes cycled by item index (a plain list, not a Colormap).
  colors = ['r', 'y', 'c', 'g', 'm']
  # Side length (in data/pixel units) of the filled label tag above each box.
  tag_size = 25

  # `range` rather than `xrange` so the function also runs on Python 3.
  for ii in range(num_ex):
    # Reverse the channel order (e.g. BGR -> RGB) before handing to imshow.
    img = x[ii][:, :, [2, 1, 0]]
    for jj in range(num_items):
      row, col = calc(ii, jj)
      ax = axarr[row, col]
      ax.imshow(img)
      nz = y_gt[ii, jj].nonzero()
      ys, xs = nz[0], nz[1]
      if ys.size == 0:
        # Empty mask: nothing to outline or label.
        continue
      color = colors[jj % len(colors)]
      left, top = xs.min(), ys.min()
      # +1 so the box encloses the last nonzero pixel inclusively.
      right, bottom = xs.max() + 1, ys.max() + 1
      ax.add_patch(
          patches.Rectangle(
              (left, top),
              right - left,
              bottom - top,
              fill=False,
              color=color))
      # Filled tag sitting directly above the box's top-left corner.
      ax.add_patch(
          patches.Rectangle(
              (left, top - tag_size),
              tag_size,
              tag_size,
              fill=True,
              color=color))
      ax.text(left + 5, top - 5, '{}'.format(jj), size=5)

  plt.tight_layout(pad=2.0, w_pad=0.0, h_pad=0.0)
  plt.savefig(fname, dpi=150)
  plt.close('all')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号