ops.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:Tensormodels 作者: asheshjain399 项目源码 文件源码
def flatten(inputs, scope=None):
  """Flattens the input while maintaining the batch_size.

    Assumes that the first dimension represents the batch.

  Args:
    inputs: a tensor of size [batch_size, ...].
    scope: Optional scope for op_scope.

  Returns:
    a flattened tensor with shape [batch_size, k].
  Raises:
    ValueError: if inputs.shape is wrong.
  """
  if len(inputs.get_shape()) < 2:
    raise ValueError('Inputs must be have a least 2 dimensions')
  dims = inputs.get_shape()[1:]
  k = dims.num_elements()
  with tf.op_scope([inputs], scope, 'Flatten'):
    return tf.reshape(inputs, [-1, k])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号