def bilateral_slice(grid, guide, name=None):
    """Slices into a bilateral grid using the guide map.

    Args:
      grid: (Tensor) grid to slice from. Either
        [batch_size, grid_h, grid_w, depth, n_outputs], or a 6-D grid
        [batch_size, grid_h, grid_w, depth, n_out, n_in]. The 6-D form
        requires the last dimension (n_in) to be statically known.
      guide: (Tensor) [batch_size, h, w] guide map to slice along.
      name: (string) name for the operation.

    Returns:
      sliced: (Tensor) sliced output, [batch_size, h, w, n_outputs] for a
        5-D grid, or [batch_size, h, w, n_out, n_in] for a 6-D grid.

    Raises:
      ValueError: if grid is 6-D and its last dimension is not statically
        known.
    """
    # NOTE(review): `name` defaults to None and is passed straight to
    # tf.name_scope; depending on the TF version this either re-enters the
    # root name scope or is rejected as a non-string name — confirm callers
    # always pass a name or that this is intended.
    with tf.name_scope(name):
        gridshape = grid.get_shape().as_list()
        is_6d_grid = len(gridshape) == 6

        if is_6d_grid:
            n_in = gridshape[5]
            if n_in is None:
                raise ValueError(
                    "bilateral_slice: the last dimension of a 6-D grid must "
                    "be statically known, got shape %s" % (gridshape,))
            # Flatten the trailing (n_out, n_in) axes into one channel axis,
            # n_in-major: channels [k*n_out:(k+1)*n_out] hold input column k.
            grid = tf.concat(tf.unstack(grid, num=n_in, axis=5), 4)

        sliced = hdrnet_ops.bilateral_slice(grid, guide)

        if is_6d_grid:
            # Undo the flattening above: split the n_out*n_in channels into
            # n_in groups of n_out and stack them on a new trailing axis.
            sliced = tf.stack(tf.split(sliced, n_in, axis=3), axis=4)
        return sliced
# pylint: enable=redefined-builtin
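# NOTE(review): the two lines that stood here, "评论列表" ("comment list")
# and "文章目录" ("table of contents"), were blog-page navigation residue
# rather than code; kept only as this comment so the module imports cleanly.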