# `math_ops` used below is TensorFlow's internal ops module.
from tensorflow.python.ops import math_ops


def training_loss(self, logits, target, features, name="training_loss"):
"""Returns training loss tensor for this head.
Training loss is different from the loss reported on the tensorboard as we
should respect the example weights when computing the gradient.
L = sum_{i} w_{i} * l_{i} / B
where B is the number of examples in the batch, l_{i}, w_{i} are individual
losses, and example weight.
Args:
logits: logits, a float tensor.
target: either a tensor for labels or in multihead case, a dict of string
to target tensor.
features: features dict.
name: Op name.
Returns:
Loss tensor.
"""
  # In the multi-head case, pick this head's target out of the dict.
  target = target[self.name] if isinstance(target, dict) else target

  loss_unweighted = self._loss_fn(logits, target)

  weight_tensor = self.get_weight_tensor(features)
  if weight_tensor is None:
    # No example weights: the plain mean already equals sum_{i} l_{i} / B.
    return math_ops.reduce_mean(loss_unweighted, name=name)
  loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
  return math_ops.reduce_mean(loss_weighted, name=name)
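

# A minimal sketch (not part of the original source) of the arithmetic the
# docstring describes, L = sum_{i} w_{i} * l_{i} / B. It uses the public
# TensorFlow 2.x API rather than the internal `math_ops` module, and the
# tensor values are made up for illustration.
import tensorflow as tf

losses = tf.constant([1.0, 2.0, 3.0, 4.0])   # per-example losses l_i, so B = 4
weights = tf.constant([1.0, 0.0, 2.0, 1.0])  # per-example weights w_i

# reduce_mean of the weighted losses divides the weighted sum by B,
# i.e. (1*1 + 0*2 + 2*3 + 1*4) / 4 = 2.75.
weighted_mean = tf.reduce_mean(losses * weights)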