def _weighted_loss(self, loss, weight_tensor):
    """Scales the flattened loss by the flattened weights.

    Both `loss` and `weight_tensor` are reshaped to shape `(-1,)` and then
    combined with `math_ops.mul`. No reduction (sum/mean) is applied in this
    method; the product is returned as-is.

    Args:
      loss: Loss values; reshaped to shape `(-1,)` before weighting.
      weight_tensor: Weights; reshaped to shape `(-1,)` before multiplying.
        NOTE(review): presumably holds one weight per loss entry -- confirm
        against callers.

    Returns:
      The result of `math_ops.mul` applied to the two reshaped inputs.
    """
    flat_shape = (-1,)
    flat_loss = array_ops.reshape(loss, shape=flat_shape)
    flat_weights = array_ops.reshape(weight_tensor, shape=flat_shape)
    return math_ops.mul(flat_loss, flat_weights)
评论列表
文章目录