def _resize_except_axis(inputs, size, axis, **kwargs):
    """ Resize a 5D tensor along two of its three spatial axes, leaving `axis` unchanged.

    The kept spatial axis is moved to tensor position 3 (just before the last dim).
    Every slice along it is resized with ``tf.image.resize_bilinear``, the slices are
    stacked back together, and the original axis order is restored.

    Parameters
    ----------
    inputs : tf.Tensor
        Rank-5 tensor. Dims 1-3 are treated as the three spatial axes; dims 0 and 4
        are never permuted (presumably batch and channels — TODO confirm with callers).
    size : sequence or 1D tf.Tensor
        Target sizes for the three spatial axes, in the original axis order.
        The entry for `axis` is not used for resizing; it is only forwarded
        (after permutation) to ``_calc_size_after_resize``.
    axis : int
        Spatial axis (0, 1 or 2) that is not resized. Negative values count from
        the end, so -1 is the same as 2.
    kwargs : dict
        Passed through to ``tf.image.resize_bilinear``.

    Returns
    -------
    tf.Tensor
        Tensor with the same axis order as `inputs`. Its dtype is float32, because
        the per-slice TensorArray is created with ``tf.float32``.
    """
    # Normalize negative spatial axes. -1 already behaved like 2 via the fallback
    # branch below; -2 and -3 previously fell through to it incorrectly.
    if axis < 0:
        axis += 3

    # `spatial_perm` moves the kept spatial axis to tensor position 3;
    # `reverse_spatial_perm` is its inverse on positions 1..3.
    if axis == 0:
        spatial_perm, reverse_spatial_perm = [2, 3, 1], [3, 1, 2]
    elif axis == 1:
        spatial_perm, reverse_spatial_perm = [1, 3, 2], [1, 3, 2]
    else:
        spatial_perm, reverse_spatial_perm = [1, 2, 3], [1, 2, 3]
    perm = [0] + spatial_perm + [4]
    reverse_perm = [0] + reverse_spatial_perm + [4]

    x = tf.transpose(inputs, perm)

    # Reorder the target sizes so they line up with the spatial dims of `x`.
    if isinstance(size, tf.Tensor):
        size = tf.unstack(size)
        size = tf.stack([size[i - 1] for i in spatial_perm])
    else:
        size = [size[i - 1] for i in spatial_perm]

    # Only the static shape is used from the helper. The 2D resize target is the
    # first two entries of the permuted `size` (the kept axis is now last).
    _, static_shape = _calc_size_after_resize(x, size, [0, 1])
    real_size = size[:-1]

    num_slices = tf.shape(x)[-2]
    array = tf.TensorArray(tf.float32, size=num_slices)

    def _loop(idx, array):
        """ Resize the slice x[:, :, :, idx] and write it at position `idx`. """
        tensor = x[:, :, :, idx]
        tensor = tf.image.resize_bilinear(tensor, size=real_size, name='resize_2d', **kwargs)
        array = array.write(idx, tensor)
        return [idx + 1, array]

    _, array = tf.while_loop(lambda idx, _: idx < num_slices, _loop, [0, array])

    # `stack` puts the slice index first; move it back to position 3 to match `x`.
    array = array.stack()
    array = tf.transpose(array, [1, 2, 3, 0, 4])
    array.set_shape(static_shape)
    return tf.transpose(array, reverse_perm)
评论列表
文章目录