image_sample.py source code

Project: ternarynet  Author: czhu95
```python
import tensorflow as tf


def ImageSample(inputs):
    """
    Sample the template image at the given coordinates by bilinear interpolation.
    It mimics the behavior described in:
    `Spatial Transformer Networks <http://arxiv.org/abs/1506.02025>`_.

    :param inputs: [template, mapping]. template has shape NHWC. mapping has
        shape NHW2, where each pair in the last dimension is a real-valued
        (y, x) coordinate.
    :returns: an NHWC output tensor whose H and W follow mapping.
    """
    template, mapping = inputs
    assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4

    # Clamp negative coordinates to 0. Casting to int32 truncates toward
    # zero, which equals floor because mapping is now non-negative.
    mapping = tf.maximum(mapping, 0.0)
    lcoor = tf.cast(mapping, tf.int32)  # floor
    ucoor = lcoor + 1

    # Cast to int32 and back instead of calling tf.floor, because
    # tf.floor has a gradient of 1 w.r.t. its input.
    # TODO: bug fixed in #951
    diff = mapping - tf.cast(lcoor, tf.float32)
    neg_diff = 1.0 - diff  # b x h2 x w2 x 2

    # Split the (y, x) pairs along the last axis
    # (TF >= 1.0 argument order: value, num_or_size_splits, axis).
    lcoory, lcoorx = tf.split(lcoor, 2, axis=3)
    ucoory, ucoorx = tf.split(ucoor, 2, axis=3)

    # Mixed corners: (lower y, upper x) and (upper y, lower x).
    lyux = tf.concat([lcoory, ucoorx], axis=3)
    uylx = tf.concat([ucoory, lcoorx], axis=3)

    diffy, diffx = tf.split(diff, 2, axis=3)
    neg_diffy, neg_diffx = tf.split(neg_diff, 2, axis=3)

    # `sample(img, coords)` gathers pixels at integer (y, x) coordinates and is
    # assumed to clip out-of-range coordinates (e.g. ucoor at the border).
    # It is defined elsewhere in image_sample.py and is not shown here.
    # Each corner is weighted by the area of the opposite sub-rectangle.
    return tf.add_n([sample(template, lcoor) * neg_diffx * neg_diffy,
                     sample(template, ucoor) * diffx * diffy,
                     sample(template, lyux) * neg_diffy * diffx,
                     sample(template, uylx) * diffy * neg_diffx],
                    name='sampled')
```
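The four-corner weighting above can be checked without TensorFlow using a small NumPy sketch. It follows the docstring's conventions (NHWC template, NHW2 mapping of real-valued (y, x) pairs) and assumes the unseen `sample` helper clamps coordinates to the image border. The names `gather_clipped` and `bilinear_sample` are hypothetical and are not part of the original file; this is an illustrative reference, not the author's implementation.

```python
import numpy as np


def gather_clipped(img, coords):
    """Hypothetical stand-in for `sample`: gather img[b, y, x, :] at integer
    (y, x) coordinates, clipping out-of-range ones to the border.

    img: (N, H, W, C); coords: (N, H2, W2, 2) integers -> (N, H2, W2, C)
    """
    n, h, w, _ = img.shape
    y = np.clip(coords[..., 0], 0, h - 1)
    x = np.clip(coords[..., 1], 0, w - 1)
    b = np.arange(n).reshape(-1, 1, 1)  # batch index, broadcast against y and x
    return img[b, y, x]


def bilinear_sample(template, mapping):
    """Bilinear sampling with the same corner weights as ImageSample."""
    mapping = np.maximum(mapping, 0.0)
    lcoor = np.floor(mapping).astype(np.int64)
    ucoor = lcoor + 1
    diff = mapping - lcoor  # fractional part, shape (N, H2, W2, 2)
    neg_diff = 1.0 - diff
    # Slicing with 0:1 / 1:2 keeps a trailing axis so weights broadcast over C.
    dy, dx = diff[..., 0:1], diff[..., 1:2]
    ndy, ndx = neg_diff[..., 0:1], neg_diff[..., 1:2]
    lyux = np.stack([lcoor[..., 0], ucoor[..., 1]], axis=-1)
    uylx = np.stack([ucoor[..., 0], lcoor[..., 1]], axis=-1)
    return (gather_clipped(template, lcoor) * ndx * ndy
            + gather_clipped(template, ucoor) * dx * dy
            + gather_clipped(template, lyux) * ndy * dx
            + gather_clipped(template, uylx) * dy * ndx)


img = np.arange(16, dtype=np.float64).reshape(1, 4, 4, 1)  # pixel = 4*y + x
pts = np.array([[[[1.5, 2.25]]]])  # one (y, x) sample point, shape (1, 1, 1, 2)
print(bilinear_sample(img, pts)[0, 0, 0, 0])  # prints 8.25
```

Because the test image is linear in y and x, bilinear interpolation reproduces it exactly (4 * 1.5 + 2.25 = 8.25). Clipping inside the gather helper, rather than clipping `ucoor` up front, keeps the four weights summing to 1 at the border, so points past the last row or column repeat the edge pixel.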