def plot_input(fname, x, y_gt, s_gt, max_items_per_row=9):
    """Plot each input image with one ground-truth item outlined per panel.

    For every example ``ii`` and item ``jj`` the panel chosen by
    ``calc(ii, jj)`` shows ``x[ii]`` (with its first three channels reversed,
    presumably BGR -> RGB -- TODO confirm) and, if the mask ``y_gt[ii, jj]``
    has any non-zero pixel, the bounding box of those pixels drawn in a colour
    cycled per item, plus a small filled tag labelled with ``jj``. The figure
    is saved to ``fname`` at 150 dpi and then closed.

    Args:
        fname: Output path passed to ``savefig``.
        x: Indexable of images; ``x[ii]`` must be 3-D with at least three
            channels on the last axis.
        y_gt: Array indexed as ``y_gt[ii, jj]``; its first two axes give the
            number of examples and items. Each ``y_gt[ii, jj]`` is treated as a
            mask whose first two axes are image rows and columns.
        s_gt: Unused; kept for interface compatibility.
        max_items_per_row: Forwarded to ``calc_row_col`` to lay out the grid.
    """
    import numpy as np

    num_ex = y_gt.shape[0]
    num_items = y_gt.shape[1]
    num_row, num_col, calc = calc_row_col(
        num_ex, num_items, max_items_per_row=max_items_per_row)
    fig, axarr = plt.subplots(num_row, num_col, figsize=(20, num_row))
    try:
        set_axis_off(axarr, num_row, num_col)
        # plt.subplots squeezes singleton grid dimensions by default (a 1-D
        # array, or a bare Axes for a 1x1 grid), which would make
        # ``[row, col]`` indexing fail; index through a 2-D view instead.
        axes = np.asarray(axarr).reshape(num_row, num_col)
        cmap = ['r', 'y', 'c', 'g', 'm']
        for ii in range(num_ex):
            # Reverse the first three channels (presumably BGR -> RGB).
            img = x[ii][:, :, [2, 1, 0]]
            for jj in range(num_items):
                row, col = calc(ii, jj)
                ax = axes[row, col]
                ax.imshow(img)
                nz = y_gt[ii, jj].nonzero()
                if nz[0].size == 0:
                    # Empty mask: show the image without a box.
                    continue
                top_left_x = nz[1].min()
                top_left_y = nz[0].min()
                # +1 so the box encloses the last non-zero pixel.
                bot_right_x = nz[1].max() + 1
                bot_right_y = nz[0].max() + 1
                color = cmap[jj % len(cmap)]
                ax.add_patch(
                    patches.Rectangle(
                        (top_left_x, top_left_y),
                        bot_right_x - top_left_x,
                        bot_right_y - top_left_y,
                        fill=False,
                        color=color))
                # Filled 25x25 tag just above the box's top-left corner,
                # carrying the item index.
                ax.add_patch(
                    patches.Rectangle(
                        (top_left_x, top_left_y - 25),
                        25,
                        25,
                        fill=True,
                        color=color))
                ax.text(
                    top_left_x + 5, top_left_y - 5, '{}'.format(jj), size=5)
        fig.tight_layout(pad=2.0, w_pad=0.0, h_pad=0.0)
        fig.savefig(fname, dpi=150)
    finally:
        # Close only this figure (even if drawing/saving raised) so figures
        # the caller still holds are not destroyed.
        plt.close(fig)
评论列表
文章目录