def _build(self, inputs):
    """Connects the MergeDims module into the graph.

    Args:
        inputs: A Tensor, or a nested structure of Tensors, to merge. The rank
            of every tensor must be greater than or equal to
            `start` + `size`.

    Returns:
        The merged Tensor when `inputs` is a single Tensor; otherwise a
        structure with the same nesting as `inputs` holding the merged
        Tensors.

    Raises:
        ValueError: If any of the `inputs` tensors has insufficient rank.
    """
    # Guard clause: a lone tensor needs no structure handling.
    if not nest.is_sequence(inputs):
        return self._merge(inputs)

    # Merge every leaf in flattened order, then rebuild the caller's
    # original nesting around the results.
    merged_leaves = []
    for leaf in nest.flatten(inputs):
        merged_leaves.append(self._merge(leaf))
    return nest.pack_sequence_as(inputs, merged_leaves)
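# NOTE(review): web-page navigation residue ("Comment list", "Article table
# of contents") captured with this snippet; not part of the module.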