loss_ops.py source code

Language: Python

Project: lsdc    Author: febert
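# Module-level imports (TensorFlow 1.x). `_scale_losses`, `_num_present`,
# `_safe_mean` and `add_loss` are assumed to be defined elsewhere in this
# same loss_ops.py module; they are not shown here.
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops


def compute_weighted_loss(losses, weight=1.0):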
  """Computes the weighted loss.

  Args:
    losses: A tensor of size [batch_size, d1, ... dN].
    weight: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.

  Returns:
    A scalar `Tensor` containing the weighted loss, averaged over the number
    of loss elements whose weight is non-zero.

  Raises:
    ValueError: If the weight is None or the shape is not compatible with the
      losses shape or if the number of dimensions (rank) of either losses or
      weight is missing.
  """
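  if weight is None:
    raise ValueError("`weight` cannot be None")
  # Accept Python lists and NumPy arrays as well as Tensors; reading `.dtype`
  # from a plain list would otherwise fail.
  losses = ops.convert_to_tensor(losses)
  input_dtype = losses.dtype
  # Do the arithmetic in float32 regardless of the caller's dtype.
  losses = math_ops.to_float(losses)
  weight = math_ops.to_float(ops.convert_to_tensor(weight))

  # The helpers need static ranks to decide which axes `weight` broadcasts
  # over, so unknown ranks are rejected up front.
  if losses.get_shape().ndims is None:
    raise ValueError("losses.get_shape().ndims cannot be None")
  if weight.get_shape().ndims is None:
    raise ValueError("weight.get_shape().ndims cannot be None")

  # Sum of weight * losses, with `weight` broadcast over the trailing
  # dimensions of `losses`.
  total_loss = _scale_losses(losses, weight)
  # Number of loss elements whose (broadcast) weight is non-zero.
  num_present = _num_present(losses, weight)
  # total_loss / num_present, or 0 when no element is present.
  mean_loss = _safe_mean(total_loss, num_present)
  # Convert the result back to the input dtype.
  mean_loss = math_ops.cast(mean_loss, input_dtype)
  # Register the result in the graph's loss collection.
  add_loss(mean_loss)
  return mean_loss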
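The arithmetic that `compute_weighted_loss` delegates to its helpers can be illustrated without TensorFlow. The sketch below is a hypothetical pure-Python re-implementation (`weighted_mean_loss` and its helpers are illustrative names, not part of any library), written under the assumption that `_scale_losses`, `_num_present` and `_safe_mean` behave as in TensorFlow's contrib `loss_ops.py`: the weight is broadcast over the trailing dimensions of `losses`, the weighted sum is divided by the number of elements whose weight is non-zero, and the result is 0 when no element is present. Nested lists stand in for tensors.

```python
from typing import List, Tuple, Union

# Hypothetical stand-in for a tensor: a number or an arbitrarily nested list.
Nested = Union[float, int, List["Nested"]]


def _flatten(x: Nested) -> List[float]:
    """Return all scalar entries of a nested list as a flat list of floats."""
    if isinstance(x, (list, tuple)):
        flat: List[float] = []
        for item in x:
            flat.extend(_flatten(item))
        return flat
    return [float(x)]


def _weighted_sum_and_count(losses: Nested, weight: Nested) -> Tuple[float, int]:
    """Walk `losses` and `weight` together; a scalar weight is broadcast over
    the whole sub-tree of `losses` it lines up with."""
    if not isinstance(weight, (list, tuple)):
        values = _flatten(losses)
        w = float(weight)
        # Elements with zero weight do not count as "present".
        count = 0 if w == 0.0 else len(values)
        return w * sum(values), count
    if not isinstance(losses, (list, tuple)) or len(losses) != len(weight):
        raise ValueError("weight shape is not compatible with losses shape")
    total, count = 0.0, 0
    for sub_losses, sub_weight in zip(losses, weight):
        t, c = _weighted_sum_and_count(sub_losses, sub_weight)
        total += t
        count += c
    return total, count


def weighted_mean_loss(losses: Nested, weight: Nested = 1.0) -> float:
    """Sum of weight * losses divided by the number of present elements."""
    if weight is None:
        raise ValueError("`weight` cannot be None")
    # Assumption: a one-element weight list ("size [1]") acts like a scalar.
    if isinstance(weight, (list, tuple)) and len(weight) == 1:
        weight = weight[0]
    total, num_present = _weighted_sum_and_count(losses, weight)
    # Mirror the "safe" mean: return 0 instead of dividing by zero.
    return total / num_present if num_present > 0 else 0.0
```

Note that the denominator is a count of present elements, not a sum of weights: with `weight=[1.0, 0.0]`, `weighted_mean_loss([[1.0, 2.0], [3.0, 4.0]], [1.0, 0.0])` drops the second row entirely and returns `1.5`, while a scalar weight of `2.0` over the same losses returns `5.0`.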