def flatten(inputs, scope=None):
    """Flattens the input while maintaining the batch_size.

    Assumes that the first dimension represents the batch. All remaining
    dimensions are collapsed into a single one whose size is their product.

    Args:
        inputs: a tensor of size [batch_size, ...].
        scope: Optional scope for op_scope.

    Returns:
        a flattened tensor with shape [batch_size, k].

    Raises:
        ValueError: if inputs has fewer than 2 dimensions, or if any of the
            non-batch dimensions is not statically known (k cannot be
            computed in that case).
    """
    shape = inputs.get_shape()
    if len(shape) < 2:
        raise ValueError('Inputs must have at least 2 dimensions')

    # Everything after the batch dimension is folded into one axis of size k.
    dims = shape[1:]
    k = dims.num_elements()
    if k is None:
        # num_elements() yields None when some dimension is unknown; passing
        # that on to reshape would fail with a much less helpful error.
        raise ValueError(
            'Inputs must have fully defined non-batch dimensions to be '
            'flattened, got shape %s' % (shape,))

    # NOTE(review): tf.op_scope was superseded by tf.name_scope in later
    # TensorFlow releases; kept as-is to match the version this file targets.
    with tf.op_scope([inputs], scope, 'Flatten'):
        # -1 lets reshape infer the batch size at run time.
        return tf.reshape(inputs, [-1, k])
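# (Web-page residue from the source article: "Comment list")
# (Web-page residue from the source article: "Article table of contents")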