def tf_batch_map_offsets(input, offsets, order=1):
    """Sample ``input`` at its own pixel grid shifted by ``offsets``.

    An integer index grid covering one (s, s) image is built, the
    per-pixel offsets are added to it, and the resulting absolute
    coordinates are handed to ``tf_batch_map_coordinates``.

    Parameters
    ----------
    input : tf.Tensor. shape = (b, s, s)
    offsets : tf.Tensor. shape = (b, s, s, 2)
    order : not read by this function; kept so existing callers that
        pass it continue to work.

    Returns
    -------
    tf.Tensor. shape = (b, s, s)
        Whatever ``tf_batch_map_coordinates`` returns for the shifted
        coordinates (shape as stated by the original documentation).
    """
    dims = tf.shape(input)
    n_batch, side = dims[0], dims[1]

    # One 2-vector per pixel, with the two spatial axes flattened together.
    flat_offsets = tf.reshape(offsets, (n_batch, -1, 2))

    # 'ij' indexing: the first output varies along axis 0, the second along
    # axis 1. Only dims[1] is used, so the same size serves for both axes.
    rows, cols = tf.meshgrid(tf.range(side), tf.range(side), indexing='ij')
    base_grid = tf.cast(tf.stack([rows, cols], axis=-1), 'float32')
    base_grid = tf.reshape(base_grid, (-1, 2))

    # Presumably replicates the grid once per batch item so it lines up with
    # flat_offsets -- tf_repeat_2d is defined outside this block; confirm.
    base_grid = tf_repeat_2d(base_grid, n_batch)

    return tf_batch_map_coordinates(input, flat_offsets + base_grid)
# Comment list
# Article table of contents