loss_ops.py source code

python

Project: lsdc  Author: febert
from tensorflow.python.ops import math_ops


def _scale_losses(losses, weight):
  """Computes the scaled loss.

  Args:
    losses: A `Tensor` of size [batch_size, d1, ... dN].
    weight: A `Tensor` of size [1], [batch_size] or [batch_size, d1, ... dN].
      The `losses` are reduced (tf.reduce_sum) over their trailing dimensions
      until their rank matches that of `weight`, at which point the reduced
      `losses` are element-wise multiplied by `weight` and a final reduce_sum
      is computed on the result.
      Conceptually, this operation is equivalent to broadcasting (tiling)
      `weight` to be the same size as `losses`, performing an element-wise
      multiplication, and summing the result.

  Returns:
    A scalar tf.float32 `Tensor` whose value represents the sum of the scaled
      `losses`.
  """
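  # Sum the losses over the trailing dimensions that `weight` does not cover,
  # so the reduced losses have the same rank as `weight`.
  start_index = max(0, weight.get_shape().ndims)
  reduction_indices = list(range(start_index, losses.get_shape().ndims))
  # `axis` is the current name of the deprecated `reduction_indices` argument.
  reduced_losses = math_ops.reduce_sum(losses, axis=reduction_indices)
  # `math_ops.mul` was renamed to `math_ops.multiply` in TensorFlow 1.0.
  reduced_losses = math_ops.multiply(reduced_losses, weight)
  return math_ops.reduce_sum(reduced_losses)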
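
The docstring says that reducing first and then multiplying gives the same result as broadcasting `weight` to the shape of `losses`, multiplying, and summing. The sketch below checks that claim in plain Python for the simplest case: `losses` of shape [batch_size, d1] and `weight` of shape [batch_size]. The names `scale_losses_two_step` and `scale_losses_broadcast` are hypothetical helpers written for this illustration. They are not part of loss_ops.py and do not use TensorFlow.

```python
def scale_losses_two_step(losses, weight):
    """Reduce each example's losses first, then weight and sum.

    losses: list of rows, shape [batch_size, d1] (illustrative 2-D case only).
    weight: list of per-example weights, shape [batch_size].
    """
    # Step 1: sum over the trailing dimension that `weight` does not cover.
    per_example = [sum(row) for row in losses]
    # Step 2: multiply element-wise by `weight`, then take the final sum.
    return sum(loss * w for loss, w in zip(per_example, weight))


def scale_losses_broadcast(losses, weight):
    """Tile `weight` across each row, multiply element-wise, and sum everything."""
    return sum(x * w for row, w in zip(losses, weight) for x in row)


losses = [[1.0, 2.0, 3.0],
          [4.0, 5.0, 6.0]]
weight = [0.5, 2.0]

print(scale_losses_two_step(losses, weight))   # 33.0
print(scale_losses_broadcast(losses, weight))  # 33.0
```

Both paths give 0.5 * (1 + 2 + 3) + 2.0 * (4 + 5 + 6) = 33.0. Reducing first simply avoids materializing a tiled copy of `weight`.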