def image_series_summary(tag, imgs, max_timesteps=10):
    """Log a time series of images as one horizontally-tiled image summary.

    For each of (up to) the first three minibatch items, ``max_timesteps``
    frames are placed side by side along the width axis and emitted through
    ``tf.summary.image``.

    Args:
        tag: Name of the image summary.
        imgs: Image tensor with time on axis 0 and the minibatch on axis 1.
            Expected layout is ``(T, batch_size, H, W, C)`` or
            ``(T, batch_size, n_obj, H, W, C)``; in the rank-6 case only
            object 0 is logged. The rank must be known statically.
        max_timesteps: Number of frames shown per series. Longer series are
            truncated; shorter ones are zero-padded at the end of the time
            axis.

    Returns:
        The summary tensor returned by ``tf.summary.image``.
    """
    # Keep at most 3 minibatch items (axis 1).
    imgs = imgs[:, :3]

    # For (T, batch_size, n_obj, H, W, C) inputs keep only the first object.
    # The static rank is checked (instead of tf.cond) because the two cases
    # have different ranks: a tf.cond merging them would lose the static
    # shape that the padding/unstacking below depends on.
    if imgs.get_shape().ndims == 6:
        imgs = imgs[:, :, 0]

    # Static shape after forcing the time axis to exactly max_timesteps.
    shape = (max_timesteps,) + tuple(imgs.get_shape()[1:])
    nt = tf.shape(imgs)[0]

    def pad():
        # Zero-pad only the end of the time axis; all other axes unpadded.
        paddings = tf.concat(
            axis=0,
            values=([[0, max_timesteps - nt]],
                    tf.zeros((len(shape) - 1, 2), tf.int32)))
        return tf.pad(imgs, paddings)

    imgs = tf.cond(tf.greater(nt, max_timesteps),
                   lambda: imgs[:max_timesteps], pad)
    # tf.unstack needs a statically known length along axis 0.
    imgs.set_shape(shape)

    # No tf.squeeze here: it would also drop a size-1 batch or channel axis
    # and produce a rank-3 tensor, which tf.summary.image rejects.
    frames = tf.unstack(imgs)
    # Tile the frames along the width axis:
    # T x (batch, H, W, C) -> (batch, H, T * W, C).
    imgs = tf.concat(axis=2, values=frames)
    return tf.summary.image(tag, imgs)
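# NOTE: the two lines below were web-page residue from the scraped source
# ("评论列表" = "Comment list", "文章目录" = "Table of contents"); kept as comments.
# 评论列表
# 文章目录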