def _build(self, inputs):
    """Connects the `TileByDim` module into the graph.

    Only the axes listed in `self._dims` are repeated; every other axis of
    `inputs` gets a multiple of 1 and is therefore left unchanged.

    Args:
      inputs: `Tensor` to tile. Its static shape is read via
        `get_shape().as_list()`, so presumably its rank must be known at
        graph-construction time (TODO confirm for dynamic-rank inputs).

    Returns:
      The tiled tensor produced by `tf.tile`.
    """
    ndims = len(inputs.get_shape().as_list())

    # Identity multiples: one entry per axis, all 1, so `tf.tile` leaves
    # any axis the user did not mention untouched.
    multiples_per_axis = [1 for _ in range(ndims)]

    # Overwrite the user-requested axes. Plain list indexing means a negative
    # axis counts from the end and an out-of-range axis raises IndexError.
    # NOTE(review): `zip` stops at the shorter of the two sequences; this
    # assumes the constructor guarantees equal lengths — confirm.
    for axis, factor in zip(self._dims, self._multiples):
        multiples_per_axis[axis] = factor

    return tf.tile(inputs, multiples=multiples_per_axis)
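# Comment list (web-page residue, not part of the code)
# Table of contents (web-page residue, not part of the code)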