layers.py source code

python

Project: lsdc    Author: febert
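# Imports this excerpt relies on (TensorFlow 1.x contrib layout).
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops


def flatten(inputs,
            outputs_collections=None,
            scope=None):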
  """Flattens the input while maintaining the batch_size.

  Assumes that the first dimension represents the batch.

  Args:
    inputs: a tensor of size [batch_size, ...].
    outputs_collections: collection to add the outputs to.
    scope: Optional scope for name_scope.

  Returns:
    a flattened tensor with shape [batch_size, k], where k is the product
    of all non-batch dimensions.
  Raises:
    ValueError: if inputs has an unknown rank or fewer than 2 dimensions,
      or if any non-batch dimension is not statically defined.
  """
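  with ops.name_scope(scope, 'Flatten', [inputs]) as sc:
    inputs = ops.convert_to_tensor(inputs)
    inputs_shape = inputs.get_shape()
    inputs_rank = inputs_shape.ndims
    # The rank must be known and include a batch dimension plus data.
    if (inputs_rank is None) or (inputs_rank < 2):
      raise ValueError('Inputs must have at least 2 dimensions.')
    # Every dimension after the batch must be known to compute k statically.
    dims = inputs_shape[1:]
    if not dims.is_fully_defined():
      raise ValueError('All non-batch dimensions of inputs must be defined.')
    # k is the product of all non-batch dimensions.
    k = dims.num_elements()
    # -1 lets the batch dimension be inferred, so it may remain unknown.
    outputs = array_ops.reshape(inputs, [-1, k])
    return utils.collect_named_outputs(outputs_collections, sc, outputs)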
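For readers without TensorFlow 1.x at hand, the shape logic of `flatten` can be sketched with NumPy. This is a minimal illustration, not code from lsdc or TensorFlow: the helper names `flattened_shape` and `flatten_array` are hypothetical, and `None` stands in for an unknown dimension the way a partially defined `TensorShape` would report it.

```python
import math
from typing import Optional, Sequence, Tuple

import numpy as np


def flattened_shape(
    shape: Optional[Sequence[Optional[int]]],
) -> Tuple[Optional[int], int]:
    """Hypothetical helper mirroring the static checks in `flatten`.

    `shape` is None when the rank is unknown; an entry is None when that
    dimension is unknown. Returns (batch_size, k).
    """
    if shape is None or len(shape) < 2:
        raise ValueError('Inputs must have at least 2 dimensions.')
    dims = shape[1:]
    if any(d is None for d in dims):
        raise ValueError('All non-batch dimensions of inputs must be defined.')
    # k is the product of every non-batch dimension.
    k = math.prod(dims)
    # The batch dimension is passed through and may itself be None.
    return shape[0], k


def flatten_array(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: reshape a concrete array to [batch_size, k]."""
    _, k = flattened_shape(x.shape)
    # -1 lets NumPy infer the batch size, like reshape(inputs, [-1, k]).
    return x.reshape(-1, k)


x = np.arange(24).reshape(2, 3, 4)
print(flatten_array(x).shape)            # (2, 12)
print(flattened_shape((None, 3, 4)))     # (None, 12)
```

Because the reshape is row-major, each row of the result holds exactly the elements of one batch entry, so no data crosses batch boundaries. An unknown batch size is fine, but an unknown non-batch dimension (for example `(None, None, 4)`) is rejected, since `k` could not be computed.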