def gather_indices_2d(x, block_shape, block_stride):
    """Build flat gather indices for 2-D blocks over the spatial axes of `x`.

    The spatial positions of `x` (axes 2 and 3, treated as height and width)
    are numbered row-major as `r * width + c`. A window of size `block_shape`
    is slid over that grid with step `block_stride` using VALID semantics
    (windows that would extend past the bottom/right edge are dropped).
    Each window yields one row holding the flat indices it covers.

    Args:
      x: Tensor of rank >= 4; only its shape is read, and axes 2 and 3 are
        used as height and width. NOTE(review): presumably laid out as
        [batch, heads, height, width, depth] — confirm against callers.
      block_shape: pair (block_h, block_w) of window sizes.
      block_stride: pair (stride_h, stride_w) of window steps.

    Returns:
      int32 Tensor of shape [num_blocks, block_h * block_w].
      - num_blocks == num_h * num_w, where
        num_h = (height - block_h) // stride_h + 1 and
        num_w = (width - block_w) // stride_w + 1.
      - Blocks are ordered row-major over the grid of window positions.
      - Within a row, cells are ordered row-major over the window.
      - If a block dimension exceeds the matching spatial size, the result
        has zero rows.
    """
    # Computed with integer arithmetic instead of a float32 conv2d with an
    # identity kernel. Float indices are exact only up to 2**24, and much
    # less under reduced-precision (e.g. TensorFloat-32) convolutions.
    x_shape = common_layers.shape_list(x)
    height, width = x_shape[2], x_shape[3]
    block_h, block_w = block_shape[0], block_shape[1]
    stride_h, stride_w = block_stride[0], block_stride[1]

    # Number of complete window positions along each axis (VALID padding).
    num_h = (height - block_h) // stride_h + 1
    num_w = (width - block_w) // stride_w + 1

    # Flat index of each window's top-left cell: shape [num_h, num_w].
    start_rows = tf.range(num_h) * stride_h
    start_cols = tf.range(num_w) * stride_w
    starts = start_rows[:, tf.newaxis] * width + start_cols[tf.newaxis, :]

    # Flat offset of each in-window cell from the top-left cell:
    # shape [block_h, block_w].
    offsets = (tf.range(block_h)[:, tf.newaxis] * width
               + tf.range(block_w)[tf.newaxis, :])

    # Broadcast [num_blocks, 1] + [1, block_h * block_w].
    indices = tf.reshape(starts, [-1, 1]) + tf.reshape(offsets, [1, -1])
    return tf.cast(indices, tf.int32)
评论列表
文章目录