def loss(self, logits, target, features):
    """Computes the loss tensor for this head.

    Without a weight tensor, the result is the plain mean of the
    per-example losses. With weights, it is the weighted average

      L = sum_{i} w_{i} * l_{i} / sum_{i} w_{i}

    Args:
      logits: logits, a float tensor.
      target: either a label tensor or, in the multi-head case, a dict
        mapping head name to label tensor; for a dict, the entry stored
        under `self.name` is used.
      features: features dict; handed to `self.get_weight_tensor` to look up
        the (optional) weight tensor.

    Returns:
      Loss tensor (a full reduction, with no axis given), named "loss".
    """
    if isinstance(target, dict):
        # Multi-head setup: select the labels that belong to this head.
        target = target[self.name]
    per_example_loss = self._loss_fn(logits, target)

    weights = self.get_weight_tensor(features)
    if weights is not None:
        weighted = self._weighted_loss(per_example_loss, weights)
        numerator = math_ops.reduce_sum(weighted)
        # The summed weights are cast to float before dividing, presumably
        # because the weight tensor is not guaranteed to be float-typed.
        # NOTE(review): there is no guard for a zero weight sum here —
        # confirm callers never pass all-zero weights.
        denominator = math_ops.to_float(math_ops.reduce_sum(weights))
        return math_ops.div(numerator, denominator, name="loss")

    # No weight tensor: every example contributes equally.
    return math_ops.reduce_mean(per_example_loss, name="loss")
评论列表
文章目录