def flatten(inputs, outputs_collections=None, scope=None):
    """Collapse every non-batch dimension of `inputs` into a single one.

    The first dimension is treated as the batch dimension and is left for
    the reshape to infer (`-1`); all remaining dimensions are multiplied
    together into one axis of size `k`.

    Args:
      inputs: a tensor (or value convertible to one) of size
        [batch_size, ...].
      outputs_collections: forwarded to `utils.collect_named_outputs`
        together with the scope name and the flattened tensor.
      scope: Optional scope for name_scope; the default name is 'Flatten'.

    Returns:
      Whatever `utils.collect_named_outputs` returns for the reshaped
      tensor, whose shape is [batch_size, k].

    Raises:
      ValueError: if the static rank of `inputs` is unknown or below 2, or
        if any dimension after the first is not statically defined.
    """
    with ops.name_scope(scope, 'Flatten', [inputs]) as name:
        tensor = ops.convert_to_tensor(inputs)
        static_shape = tensor.get_shape()
        rank = static_shape.ndims

        # A batch axis plus at least one feature axis is required.
        # NOTE(review): message texts below are kept verbatim (including
        # the "a least" wording) in case callers match on them -- confirm
        # before rewording.
        rank_is_usable = rank is not None and rank >= 2
        if not rank_is_usable:
            raise ValueError('Inputs must have a least 2 dimensions.')

        # Every axis except the batch axis must be known statically so the
        # flattened size can be computed as a plain Python integer.
        feature_shape = static_shape[1:]
        if not feature_shape.is_fully_defined():
            raise ValueError('Inputs 2nd dimension must be defined.')

        flattened = array_ops.reshape(
            tensor, [-1, feature_shape.num_elements()])
        return utils.collect_named_outputs(
            outputs_collections, name, flattened)
评论列表
文章目录