def _weighted_loss(loss, weight):
    """Returns the per-element weighted loss as a flattened 1-D tensor.

    No reduction is performed: `loss` and `weight` are each reshaped to
    rank 1 and multiplied elementwise. Summing or averaging the result is
    left to the caller.

    Args:
      loss: Tensor of unweighted losses. Any shape; it is flattened.
      weight: Tensor of weights, flattened the same way. Presumably holds
        one weight per element of `loss` (a single-element weight would
        also broadcast) -- NOTE(review): confirm against callers.

    Returns:
      A 1-D tensor equal to `reshape(loss, [-1]) * reshape(weight, [-1])`.
    """
    unweighted_loss = array_ops.reshape(loss, shape=(-1,))
    flat_weight = array_ops.reshape(weight, shape=(-1,))
    # `math_ops.mul` was renamed to `multiply` in TensorFlow 1.0 (the old
    # public name was removed); the elementwise semantics are unchanged.
    return math_ops.multiply(unweighted_loss, flat_weight)
评论列表
文章目录