def _scale_losses(losses, weight):
    """Weights `losses` by `weight` and sums the result.

    `losses` is first summed over its trailing dimensions, i.e. every axis
    whose index is at or beyond the rank of `weight`. The partially reduced
    tensor is then multiplied element-wise by `weight`, and that product is
    summed once more. Conceptually, this matches broadcasting (tiling)
    `weight` up to the shape of `losses`, multiplying element-wise, and
    summing everything.

    Args:
      losses: A `Tensor` of size [batch_size, d1, ... dN].
      weight: A `Tensor` of size [1], [batch_size] or
        [batch_size, d1, ... dN].

    Returns:
      A scalar `Tensor` (documented as tf.float32 by the original contract)
      whose value is the sum of the scaled `losses`.
    """
    # Axes of `losses` starting at the rank of `weight` get collapsed first.
    first_axis = max(0, weight.get_shape().ndims)
    losses_rank = losses.get_shape().ndims
    # Empty if `weight` already has at least the rank of `losses`.
    trailing_axes = list(range(first_axis, losses_rank))

    partial_sums = math_ops.reduce_sum(
        losses, reduction_indices=trailing_axes)
    weighted = math_ops.mul(partial_sums, weight)
    return math_ops.reduce_sum(weighted)
评论列表
文章目录