eval_tools.py source code

Project: hart  Author: akosiorek
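import tensorflow as tf


def image_series_summary(tag, imgs, max_timesteps=10):
    """Log a time series of images as one image summary, tiling time steps side by side."""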
    # take only 3 items from the minibatch
    imgs = imgs[:, :3]

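    # assume imgs.shape == (T, batch_size, n_obj, H, W, C)
    # let's log only for the 1st object; the rank must be known statically,
    # because the static shape is used below, so a plain Python check suffices
    # (the original tf.cond call discarded its result and had no effect)
    if imgs.get_shape().ndims == 6:
        imgs = imgs[:, :, 0]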

    shape = (max_timesteps,) + tuple(imgs.get_shape()[1:])
    nt = tf.shape(imgs)[0]

    def pad():
        paddings = tf.concat(axis=0, values=([[0, max_timesteps - nt]], tf.zeros((len(shape) - 1, 2), tf.int32)))
        return tf.pad(imgs, paddings)

    imgs = tf.cond(tf.greater(nt, max_timesteps), lambda: imgs[:max_timesteps], pad)
    imgs.set_shape(shape)
    imgs = tf.squeeze(imgs)
    imgs = tf.unstack(imgs)

    # concatenate along the columns
    imgs = tf.concat(axis=2, values=imgs)
    tf.summary.image(tag, imgs)
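
The pad-or-truncate and tiling steps above can be easier to follow outside the graph. Below is a minimal NumPy sketch of the same idea, not the author's implementation: the helper name `tile_image_series` and the `max_items` parameter are hypothetical, and it assumes the first-object selection is meant to be applied to 6-D input, as the comments in the function suggest. It also omits the `tf.squeeze` step.

```python
import numpy as np


def tile_image_series(imgs, max_timesteps=10, max_items=3):
    """Illustrative NumPy analogue of the pad/truncate-and-tile steps."""
    # keep at most max_items entries of the minibatch
    imgs = np.asarray(imgs)[:, :max_items]
    # (T, B, n_obj, H, W, C): keep only the first object
    if imgs.ndim == 6:
        imgs = imgs[:, :, 0]
    nt = imgs.shape[0]
    if nt > max_timesteps:
        # truncate a series that is too long
        imgs = imgs[:max_timesteps]
    else:
        # zero-pad a short series along the time axis only
        pad = [(0, max_timesteps - nt)] + [(0, 0)] * (imgs.ndim - 1)
        imgs = np.pad(imgs, pad, mode="constant")
    # (T, B, H, W, C) -> (B, H, T * W, C): time steps laid out left to right
    return np.concatenate(list(imgs), axis=2)


series = np.ones((4, 5, 2, 8, 8, 3), dtype=np.float32)
strip = tile_image_series(series)
print(strip.shape)  # (3, 8, 80, 3)
```

Each of the 3 retained batch entries becomes one 8 x 80 strip: the first 4 tiles hold the actual frames and the remaining 6 are zero padding, so every summary has the same width regardless of the series length.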