resize.py source code


Project: dataset    Author: analysiscenter
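```python
import numpy as np
import tensorflow as tf  # TF 1.x API (tf.image.resize_bilinear, tf.while_loop)


def _resize_except_axis(inputs, size, axis, **kwargs):
    """ Resize a 5D tensor `(batch, d0, d1, d2, channels)` to `size` along every
    spatial axis except `axis`, which keeps its original length.
    Extra `kwargs` (e.g. `align_corners`) are passed to `tf.image.resize_bilinear`. """
    perm = np.arange(5)
    reverse_perm = np.arange(5)

    # Move the excluded spatial axis to position 3 (just before channels);
    # reverse_spatial_perm is the inverse permutation that restores the layout.
    if axis == 0:
        spatial_perm = [2, 3, 1]
        reverse_spatial_perm = [3, 1, 2]
    elif axis == 1:
        spatial_perm = [1, 3, 2]
        reverse_spatial_perm = [1, 3, 2]
    else:
        spatial_perm = [1, 2, 3]
        reverse_spatial_perm = [1, 2, 3]

    perm[1:4] = spatial_perm
    reverse_perm[1:4] = reverse_spatial_perm
    x = tf.transpose(inputs, perm)

    # Reorder the target size to match the transposed layout
    if isinstance(size, tf.Tensor):
        size = tf.unstack(size)
        size = [size[i-1] for i in spatial_perm]
        size = tf.stack(size)
    else:
        size = [size[i-1] for i in spatial_perm]

    # `_calc_size_after_resize` is defined elsewhere in the original source (not shown here).
    # Only its static output shape is used; the 2D target size is the first two entries of `size`.
    _, static_shape = _calc_size_after_resize(x, size, [0, 1])
    real_size = size[:-1]
    array = tf.TensorArray(tf.float32, size=tf.shape(x)[-2])

    def _loop(idx, array):
        # idx-th 2D slice along the excluded axis: (batch, h, w, channels)
        tensor = x[:, :, :, idx, :]
        tensor = tf.image.resize_bilinear(tensor, size=real_size, name='resize_2d', **kwargs)
        array = array.write(idx, tensor)
        return [idx + 1, array]

    _, array = tf.while_loop(lambda i, array: i < tf.shape(x)[-2], _loop, [0, array])
    # Stacked shape is (excluded, batch, h', w', channels); move the excluded axis back to position 3
    array = array.stack()
    array = tf.transpose(array, [1, 2, 3, 0, 4])
    array.set_shape(static_shape)
    array = tf.transpose(array, reverse_perm)
    return array
```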
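To check the permutation bookkeeping without a TensorFlow 1.x graph, the same idea can be sketched in NumPy: transpose the excluded axis next to the channels, resize every 2D slice along it, stack, and apply the inverse permutation. This is a minimal sketch under stated assumptions, not the library's code: `resize_except_axis_np`, `_bilinear_2d` and `SPATIAL_PERMS` are hypothetical names, a plain Python loop stands in for `tf.while_loop`, the static-shape handling of `_calc_size_after_resize` is omitted, and the simple bilinear sampling (source coordinate = output index × in_size / out_size, upper neighbour clamped at the edge) is an assumption that is not guaranteed to match `tf.image.resize_bilinear` numerically.

```python
import numpy as np

# axis -> (spatial_perm, reverse_spatial_perm), same tables as in _resize_except_axis
SPATIAL_PERMS = {
    0: ([2, 3, 1], [3, 1, 2]),
    1: ([1, 3, 2], [1, 3, 2]),
    2: ([1, 2, 3], [1, 2, 3]),
}


def _bilinear_2d(img, out_h, out_w):
    """Hypothetical helper: bilinear resize of a (batch, h, w, channels) array."""
    _, h, w, _ = img.shape
    # Source coordinate = output index * in_size / out_size; upper neighbour clamped at the edge
    ys = np.arange(out_h) * (h / out_h)
    xs = np.arange(out_w) * (w / out_w)
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, h - 1)
    x1 = np.minimum(x0 + 1, w - 1)
    wy = (ys - y0)[None, :, None, None]
    wx = (xs - x0)[None, None, :, None]
    top = img[:, y0][:, :, x0] * (1 - wx) + img[:, y0][:, :, x1] * wx
    bottom = img[:, y1][:, :, x0] * (1 - wx) + img[:, y1][:, :, x1] * wx
    return top * (1 - wy) + bottom * wy


def resize_except_axis_np(inputs, size, axis):
    """Resize (batch, d0, d1, d2, channels) to `size`, keeping spatial axis `axis` unchanged."""
    spatial_perm, reverse_spatial_perm = SPATIAL_PERMS[axis]
    perm = np.arange(5)
    reverse_perm = np.arange(5)
    perm[1:4] = spatial_perm
    reverse_perm[1:4] = reverse_spatial_perm
    x = np.transpose(inputs, perm)              # (batch, a, b, kept, channels)
    size = [size[i - 1] for i in spatial_perm]  # target size in the transposed order
    out_h, out_w = size[:-1]                    # the kept axis' entry is ignored
    # Python loop in place of tf.while_loop: resize each 2D slice along the kept axis
    slices = [_bilinear_2d(x[:, :, :, k, :], out_h, out_w) for k in range(x.shape[3])]
    out = np.stack(slices, axis=3)              # (batch, a', b', kept, channels)
    return np.transpose(out, reverse_perm)
```

For an input of shape `(2, 4, 6, 8, 3)`, `resize_except_axis_np(x, (99, 12, 4), axis=0)` returns shape `(2, 4, 12, 4, 3)`: the first spatial axis keeps its length 4 and the 99 in `size` is ignored, exactly as the reordered `size[:-1]` in the TensorFlow version implies.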