def _scale_losses(losses, weights):
  """Computes the scaled loss.

  Args:
    losses: A `Tensor` of size [batch_size, d1, ... dN].
    weights: A `Tensor` of size [1], [batch_size] or [batch_size, d1, ... dN].
      The `losses` are reduced (tf.reduce_sum) over their trailing axes until
      their rank matches that of `weights`, at which point the reduced
      `losses` are element-wise multiplied by `weights` and a final
      reduce_sum is computed on the result. Conceptually, this operation is
      equivalent to broadcasting (tiling) `weights` to be the same size as
      `losses`, performing an element-wise multiplication, and summing the
      result.

  Returns:
    A scalar `Tensor` (same dtype as `losses`) whose value represents the sum
    of the scaled `losses`.

  Raises:
    ValueError: If the rank of `losses` or `weights` is not statically known.
  """
  weights_rank = weights.get_shape().ndims
  losses_rank = losses.get_shape().ndims
  # The reduction axes are computed in Python, so both ranks must be known
  # statically. Without this check, an unknown (None) weights rank either
  # raises an opaque TypeError (Python 3) or, because Python 2 evaluates
  # `max(0, None)` to 0, silently sums `losses` over *all* axes first. That
  # yields sum(losses) * sum(weights) instead of the weighted sum.
  if weights_rank is None:
    raise ValueError(
        'weights must have a statically known rank, got shape %s.' %
        weights.get_shape())
  if losses_rank is None:
    raise ValueError(
        'losses must have a statically known rank, got shape %s.' %
        losses.get_shape())

  # First, sum `losses` over its trailing axes so that its rank matches that
  # of `weights`. When the ranks are already equal, the list is empty and no
  # reduction happens.
  reduction_indices = list(range(weights_rank, losses_rank))
  reduced_losses = math_ops.reduce_sum(losses,
                                       reduction_indices=reduction_indices)
  # Scale by `weights` (broadcasting as needed), then sum to a scalar.
  reduced_losses = math_ops.multiply(reduced_losses, weights)
  return math_ops.reduce_sum(reduced_losses)
loss_ops.py 文件源码
python
阅读 17
收藏 0
点赞 0
评论 0
评论列表
文章目录