loss_ops.py source code

Project: lsdc  Author: febert

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def _num_present(losses, weight, per_batch=False):
  """Computes the number of elements in the loss function induced by `weight`.

  A given weight tensor induces different numbers of usable elements in the
  `losses` tensor. The `weight` tensor is broadcast across `losses` for all
  possible dimensions. For example, if `losses` is a tensor of shape
  [4, 5, 6, 3] and `weight` is a tensor of shape [4, 5], then `weight` is, in
  effect, tiled to match the shape of `losses`. After this effective tiling,
  the total number of present elements is the number of non-zero weights.

  Args:
    losses: A tensor of size [batch_size, d1, ... dN].
    weight: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.
    per_batch: Whether to return the number of elements per batch or as a sum
      total.

  Returns:
    The number of present (non-zero) elements in the losses tensor. If
    `per_batch` is True, the value is returned as a tensor of size
    [batch_size]. Otherwise, a single scalar tensor is returned.
  """
  # Squeeze so that a weight of shape [2, 1] is treated as shape [2].
  weight = array_ops.squeeze(weight)

  # If the weight is a scalar, it's easy to compute: every element of `losses`
  # is present unless the weight is zero.
  if weight.get_shape().ndims == 0:
    batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses),
                                                   [0], [1]), [])
    num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)),
                                 math_ops.to_float(batch_size))
    # `where` replaces the pre-1.0 `select` op.
    num_per_batch = array_ops.where(math_ops.equal(weight, 0),
                                    0.0, num_per_batch)
    num_per_batch = math_ops.multiply(array_ops.ones(
        array_ops.reshape(batch_size, [1])), num_per_batch)
    return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)

  # The remaining logic needs the static rank of `weight`.
  weight_ndims = weight.get_shape().ndims
  if weight_ndims is None:
    raise ValueError('`weight` must have a statically known rank.')

  # First, count the number of nonzero weights in each batch entry:
  num_nonzero_per_batch = math_ops.reduce_sum(
      math_ops.to_float(math_ops.not_equal(weight, 0)),
      axis=list(range(1, weight_ndims)))

  # Next, determine the number of elements that each weight broadcasts to:
  broadcast_dims = array_ops.slice(array_ops.shape(losses),
                                   [weight_ndims], [-1])
  num_to_broadcast = math_ops.to_float(math_ops.reduce_prod(broadcast_dims))

  num_per_batch = math_ops.multiply(num_nonzero_per_batch, num_to_broadcast)
  return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)
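To make the counting rule in the docstring concrete, here is a minimal NumPy sketch of the same logic. It is an illustration, not the project's code: `num_present_reference` is a hypothetical name, NumPy stands in for the TensorFlow ops, and it assumes `weight` is a scalar, has shape [1], or has a shape matching a leading prefix of `losses.shape` (optionally with trailing size-1 dimensions).

```python
import numpy as np


def num_present_reference(losses, weight, per_batch=False):
  """Hypothetical NumPy sketch of the counting rule used by `_num_present`.

  Assumes `weight` is a scalar, has shape [1], or has shape
  [batch_size, d1, ... dK] matching a prefix of `losses.shape`.
  """
  losses = np.asarray(losses)
  # Mirror the squeeze in the original: drop size-1 dims, so [2, 1] -> [2].
  # (As in the original, this also drops a batch dimension of size 1.)
  weight = np.squeeze(np.asarray(weight, dtype=np.float64))
  batch_size = losses.shape[0]

  if weight.ndim == 0:
    # Scalar weight: every element counts unless the weight is zero.
    per_example = 0.0 if weight == 0 else losses.size / batch_size
    counts = np.full(batch_size, per_example, dtype=np.float64)
  else:
    # Count the nonzero weights in each batch entry...
    nonzero = (weight.reshape(weight.shape[0], -1) != 0).sum(axis=1)
    # ...then multiply by how many loss elements each weight is tiled over.
    tiled = int(np.prod(losses.shape[weight.ndim:]))
    counts = nonzero.astype(np.float64) * tiled

  return counts if per_batch else float(counts.sum())


if __name__ == "__main__":
  # Shapes from the docstring; only the shape of `losses` matters here.
  losses = np.zeros((4, 5, 6, 3))
  weight = np.ones((4, 5))
  weight[0, :] = 0.0  # batch entry 0 contributes nothing
  weight[1, 2] = 0.0  # batch entry 1 loses one of its five weights

  print(num_present_reference(losses, weight, per_batch=True))  # 0, 72, 90, 90
  print(num_present_reference(losses, weight))                  # 252.0
  print(num_present_reference(losses, 0.5, per_batch=True))     # 90 each
```

With the docstring's shapes, each weight is tiled over 6 * 3 = 18 loss elements, so a batch entry with five nonzero weights counts 90 elements and one with four counts 72.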