def compute_weighted_loss(
        losses, weights=_WEIGHT_SENTINEL, weight=_WEIGHT_SENTINEL):
    """Reduces `losses` to one weighted scalar and passes it to `add_loss`.

    The computation is carried out on float versions of both inputs; the
    final value is cast back to the dtype `losses` had after tensor
    conversion. The effective weights are obtained from
    `_weights(weights, weight)`.

    Args:
      losses: A tensor of size [batch_size, d1, ... dN].
      weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.
        When its rank is greater than 1 and its last dimension is compatible
        with 1, that last dimension is squeezed away before use.
      weight: Deprecated alias for `weights`.

    Returns:
      A scalar `Tensor` holding the weighted loss, in the original dtype of
      `losses`. The same tensor is handed to `add_loss` before returning.

    Raises:
      ValueError: If the static rank of `losses` or of the weights is
        unknown. Per the original documentation, also if `weights` is `None`
        or its shape is not compatible with `losses` (presumably raised by
        the helpers this function calls).
    """
    resolved_weights = _weights(weights, weight)

    # Remember the incoming dtype so the result can be cast back at the end.
    loss_tensor = ops.convert_to_tensor(losses)
    original_dtype = loss_tensor.dtype
    loss_tensor = math_ops.to_float(loss_tensor)

    weight_tensor = ops.convert_to_tensor(resolved_weights)
    weight_tensor = math_ops.to_float(weight_tensor)

    # Both ranks must be statically known.
    loss_rank = loss_tensor.get_shape().ndims
    if loss_rank is None:
        raise ValueError("losses.get_shape().ndims cannot be None")

    weight_shape = weight_tensor.get_shape()
    weight_rank = weight_shape.ndims
    if weight_rank is None:
        raise ValueError("weight.get_shape().ndims cannot be None")

    # Drop a trailing size-1 (or unknown-size) dimension from the weights.
    if weight_rank > 1:
        if weight_shape.dims[-1].is_compatible_with(1):
            weight_tensor = array_ops.squeeze(weight_tensor, [-1])

    scaled = _scale_losses(loss_tensor, weight_tensor)
    present = _num_present(loss_tensor, weight_tensor)

    # Convert the result back to the input type.
    averaged = math_ops.cast(_safe_mean(scaled, present), original_dtype)
    add_loss(averaged)
    return averaged
# (Web-page residue, not part of the code: "Comment list" / "Table of contents".)