def _weighted_loss(loss, weight_tensor):
    """Return the weight-normalized sum of `loss` as a single tensor.

    Both inputs are flattened to 1-D, multiplied element-wise and summed;
    that total is then divided by the float-cast sum of `weight_tensor`.
    The resulting op is named "loss".

    Args:
        loss: Tensor of loss values (flattened internally).
        weight_tensor: Tensor of weights (flattened internally).
            Assumes it has the same number of elements as `loss` --
            TODO confirm against callers.

    Returns:
        The tensor produced by dividing the weighted-loss sum by the
        weight sum.
    """
    # NOTE(review): there is no guard for a zero weight sum here; callers
    # presumably guarantee at least one non-zero weight -- confirm.
    flat_loss = array_ops.reshape(loss, shape=(-1,))
    flat_weights = array_ops.reshape(weight_tensor, shape=(-1,))
    per_element = math_ops.mul(flat_loss, flat_weights)

    numerator = math_ops.reduce_sum(per_element)
    # The denominator sums the original (un-reshaped) weights and is cast
    # to float before the division.
    denominator = math_ops.to_float(math_ops.reduce_sum(weight_tensor))
    return math_ops.div(numerator, denominator, name="loss")
评论列表
文章目录