def _rescale_eval_loss(loss, weights):
  """Rescales evaluation loss according to the given weights.

  In the training loss the weights are not considered in the denominator,
  whereas the evaluation loss should be divided by the sum of the weights.
  Dividing by the mean weight converts one into the other, because

    loss / mean_i(w_i) == loss * N / sum_i(w_i)

  so the rescaling factor is R = N / sum_{i} w_{i}, where N is the number of
  elements in `weights` (for a scalar weight w this reduces to R = 1 / w).

  Args:
    loss: the scalar weighted loss.
    weights: weight coefficients. Either a scalar, or a `Tensor` of shape
      [batch_size].

  Returns:
    The given loss multiplied by the rescaling factor.
  """
  # mean(w) == sum(w) / N, so dividing by it multiplies `loss` by R.
  # NOTE(review): nothing guards against a zero mean weight; all-zero weights
  # would presumably yield inf/NaN — confirm callers never pass them.
  # NOTE(review): per the TF docs, `div` follows Python 2 division semantics,
  # so integer-typed inputs would floor-divide; presumably floats — verify.
  return math_ops.div(loss, math_ops.reduce_mean(weights))
# 评论列表 (scraped page residue: "comment list")
# 文章目录 (scraped page residue: "table of contents")