def compute_weighted_loss(losses, weight=1.0):
    """Computes the weighted loss.

    Both inputs are cast to float for the computation. The result is cast
    back to the dtype of `losses`, registered through `add_loss`, and
    returned.

    Args:
      losses: A tensor of size [batch_size, d1, ... dN]. Tensor-like values
        are converted with `ops.convert_to_tensor` before use.
      weight: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.

    Returns:
      A scalar `Tensor` that returns the weighted loss.

    Raises:
      ValueError: If the weight is None or the shape is not compatible with
        the losses shape or if the number of dimensions (rank) of either
        losses or weight is missing.
    """
    if weight is None:
        raise ValueError("`weight` cannot be None")

    # Convert first so that `.dtype` and `.get_shape()` are available even
    # when the caller passes a tensor-like value rather than a Tensor; this
    # mirrors how `weight` is handled below.
    losses = ops.convert_to_tensor(losses)
    # Remember the caller's dtype so the result can be cast back to it.
    input_dtype = losses.dtype
    losses = math_ops.to_float(losses)
    weight = math_ops.to_float(ops.convert_to_tensor(weight))

    # The static rank of both tensors must be known before they are passed to
    # the helpers below.
    if losses.get_shape().ndims is None:
        raise ValueError("losses.get_shape().ndims cannot be None")
    if weight.get_shape().ndims is None:
        raise ValueError("weight.get_shape().ndims cannot be None")

    total_loss = _scale_losses(losses, weight)
    num_present = _num_present(losses, weight)
    # NOTE(review): presumably total_loss / num_present, guarded against a
    # zero count -- confirm against `_safe_mean`.
    mean_loss = _safe_mean(total_loss, num_present)

    # Convert the result back to the input type.
    mean_loss = math_ops.cast(mean_loss, input_dtype)
    # Side effect: the loss is registered via `add_loss` as well as returned.
    add_loss(mean_loss)
    return mean_loss
评论列表
文章目录