loss_ops.py source code

Project: lsdc, Author: febert
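```python
# Module imports this function relies on. `_WEIGHT_SENTINEL`, `_weights`,
# `_scale_losses`, `_num_present`, `_safe_mean` and `add_loss` are defined
# elsewhere in loss_ops.py and are not shown in this excerpt.
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def compute_weighted_loss(
    losses, weights=_WEIGHT_SENTINEL, weight=_WEIGHT_SENTINEL):
  """Computes the weighted loss.

  Args:
    losses: A tensor of size [batch_size, d1, ... dN].
    weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.
    weight: Deprecated alias for `weights`.

  Returns:
    A scalar `Tensor` containing the weighted loss.

  Raises:
    ValueError: If `weights` is `None` or the shape is not compatible with
      `losses`, or if the number of dimensions (rank) of either `losses` or
      `weights` is missing.
  """
  # Pick up `weights`, or the deprecated `weight` alias if that was passed.
  weights = _weights(weights, weight)
  losses = ops.convert_to_tensor(losses)
  input_dtype = losses.dtype
  losses = math_ops.to_float(losses)
  weights = math_ops.to_float(ops.convert_to_tensor(weights))

  # Both ranks must be known statically.
  if losses.get_shape().ndims is None:
    raise ValueError("losses.get_shape().ndims cannot be None")
  weights_shape = weights.get_shape()
  if weights_shape.ndims is None:
    raise ValueError("weights.get_shape().ndims cannot be None")

  # Drop a trailing size-1 dimension, e.g. [batch_size, 1] -> [batch_size].
  if weights_shape.ndims > 1 and weights_shape.dims[-1].is_compatible_with(1):
    weights = array_ops.squeeze(weights, [-1])

  # Weighted sum of the losses, number of elements with a non-zero weight,
  # then a mean that avoids dividing by zero.
  total_loss = _scale_losses(losses, weights)
  num_present = _num_present(losses, weights)
  mean_loss = _safe_mean(total_loss, num_present)
  # Convert the result back to the input type.
  mean_loss = math_ops.cast(mean_loss, input_dtype)
  # Register the result in the module's loss collection.
  add_loss(mean_loss)
  return mean_loss
```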
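To make the reduction concrete, here is a minimal NumPy sketch of what `compute_weighted_loss` computes. It assumes that the helpers not shown above behave as their names suggest: `_scale_losses` returns the sum of the losses scaled by the weights, `_num_present` returns the number of loss elements whose broadcast weight is non-zero, and `_safe_mean` divides the two but returns 0 when nothing is present. `weighted_mean_loss` is a hypothetical name, and the sketch leaves out `add_loss` and the static-rank checks, which have no NumPy counterpart.

```python
import numpy as np


def weighted_mean_loss(losses, weights=1.0):
    """Hypothetical NumPy approximation of compute_weighted_loss.

    Assumes _scale_losses / _num_present / _safe_mean compute a weighted
    sum, a count of non-zero-weight elements, and a zero-safe mean.
    """
    losses = np.asarray(losses)
    input_dtype = losses.dtype
    losses = losses.astype(np.float32)
    weights = np.asarray(weights, dtype=np.float32)

    # Mirror the squeeze step: drop a trailing size-1 axis from weights.
    if weights.ndim > 1 and weights.shape[-1] == 1:
        weights = np.squeeze(weights, axis=-1)
    if weights.ndim > losses.ndim:
        raise ValueError("weights must not have more dimensions than losses")

    # Weights of shape [batch_size, d1, ..., dK] apply to every element of
    # the trailing dimensions, so append singleton axes before broadcasting
    # (plain NumPy broadcasting would align the shapes from the right).
    trailing = (1,) * (losses.ndim - weights.ndim)
    broadcast_weights = np.broadcast_to(
        weights.reshape(weights.shape + trailing), losses.shape)

    total_loss = np.sum(losses * broadcast_weights)
    # Count only the loss elements that carry a non-zero weight.
    num_present = np.count_nonzero(broadcast_weights)
    # Zero-safe mean: return 0 instead of dividing by zero.
    mean_loss = total_loss / num_present if num_present > 0 else np.float32(0.0)
    # Cast back to the dtype of the original losses.
    return np.asarray(mean_loss).astype(input_dtype)


print(weighted_mean_loss([[1.0, 2.0], [3.0, 4.0]], [1.0, 0.0]))
```

With per-example weights `[1.0, 0.0]`, only the first row counts: the weighted sum is 3 over 2 present elements, so the result is 1.5. Because the result is cast back to the input dtype, integer losses would truncate this to 1, just as the `math_ops.cast(mean_loss, input_dtype)` line does in the original.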