def ImageSample(inputs):
    """
    Sample the template image at the given real-valued coordinates, using
    bilinear interpolation.

    It mimics the sampling behavior described in
    `Spatial Transformer Networks <http://arxiv.org/abs/1506.02025>`_.

    :param inputs: a pair ``[template, mapping]``. ``template`` is a rank-4
        NHWC tensor. ``mapping`` is a rank-4 NHW2 tensor in which each pair
        along the last axis is a real-valued ``(y, x)`` coordinate into
        ``template``. Negative coordinates are clamped to 0 before sampling.
    :returns: the interpolated tensor, named ``'sampled'``. It is presumably
        NHWC with the spatial size of ``mapping``; the exact shape depends on
        the ``sample`` helper defined elsewhere (confirm).
    """
    template, mapping = inputs
    assert template.get_shape().ndims == 4 and mapping.get_shape().ndims == 4, (
        template.get_shape(), mapping.get_shape())

    # Clamp negative coordinates to 0. Since every coordinate is then >= 0,
    # the truncating int cast below is equivalent to floor.
    # NOTE(review): no upper clamp happens here; presumably `sample` deals
    # with indices past the image border -- confirm.
    mapping = tf.maximum(mapping, 0.0)

    # Integer corner coordinates: lcoor = floor(mapping), ucoor = lcoor + 1.
    # An int32 cast is used instead of tf.floor on purpose: the cast is not
    # differentiable, so `diff` below keeps a gradient of 1 w.r.t. `mapping`.
    # tf.floor used to propagate a gradient of 1 to its input, which would
    # cancel that gradient.
    # TODO: that tf.floor bug was reportedly fixed in #951; revisit this
    # workaround once the fix is available.
    lcoor = tf.cast(mapping, tf.int32)
    ucoor = lcoor + 1

    # Fractional offsets from the lower corner (in [0, 1) for finite inputs)
    # and their complements; shape is (batch, h2, w2, 2).
    diff = mapping - tf.cast(lcoor, tf.float32)
    neg_diff = 1.0 - diff

    # Split each (y, x) pair along axis 3 and build the two "mixed" corners:
    # (lower y, upper x) and (upper y, lower x).
    lcoory, lcoorx = tf.split(3, 2, lcoor)
    ucoory, ucoorx = tf.split(3, 2, ucoor)
    lyux = tf.concat(3, [lcoory, ucoorx])
    uylx = tf.concat(3, [ucoory, lcoorx])
    diffy, diffx = tf.split(3, 2, diff)
    neg_diffy, neg_diffx = tf.split(3, 2, neg_diff)

    # Bilinear interpolation: each of the four corners is weighted by the
    # product of the complementary fractional distances along y and x.
    return tf.add_n([sample(template, lcoor) * neg_diffx * neg_diffy,
                     sample(template, ucoor) * diffx * diffy,
                     sample(template, lyux) * neg_diffy * diffx,
                     sample(template, uylx) * diffy * neg_diffx],
                    name='sampled')
评论列表
文章目录