target_column.py source code

python

Project: lsdc    Author: febert
from tensorflow.python.ops import math_ops


def training_loss(self, logits, target, features, name="training_loss"):
    """Returns training loss tensor for this head.

    The training loss differs from the loss reported on TensorBoard because the
    example weights must be respected when computing the gradient:

      L = sum_{i} w_{i} * l_{i} / B

    where B is the number of examples in the batch, and l_{i} and w_{i} are the
    individual loss and the weight of example i, respectively.

    Args:
      logits: logits, a float tensor.
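      target: either a label tensor or, in the multi-head case, a dict mapping
        head names (strings) to target tensors.
      features: dict of features; the example-weight tensor, if any, is looked
        up from it.
      name: name of the returned op.

    Returns:
      A scalar loss tensor.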
    """
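    # Multi-head case: pick out this head's target by its name.
    target = target[self.name] if isinstance(target, dict) else target
    # Per-example losses l_i, before any weighting.
    loss_unweighted = self._loss_fn(logits, target)

    weight_tensor = self.get_weight_tensor(features)
    if weight_tensor is None:
      # No example weights: plain mean over the batch.
      return math_ops.reduce_mean(loss_unweighted, name=name)
    # Scale each l_i by its weight w_i, then divide the sum by the batch size B.
    loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
    return math_ops.reduce_mean(loss_weighted, name=name)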
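To check the formula L = sum_{i} w_{i} * l_{i} / B by hand, here is a minimal pure-Python sketch that mirrors the control flow of `training_loss` without TensorFlow. It is an illustration under stated assumptions, not the library's code: `ToyHead`, `squared_error`, and the `weight_column_name` key lookup are hypothetical stand-ins for the head's `_loss_fn`, `get_weight_tensor`, and `_weighted_loss`.

```python
# Hypothetical per-example loss standing in for the head's _loss_fn:
# squared error between each scalar logit and its target.
def squared_error(logits, target):
    return [(x - y) ** 2 for x, y in zip(logits, target)]


class ToyHead:
    """Pure-Python mirror of training_loss (hypothetical, not the TF head)."""

    def __init__(self, name, weight_column_name=None):
        self.name = name
        # Hypothetical: the features key that holds the example weights.
        self.weight_column_name = weight_column_name

    def get_weight_tensor(self, features):
        # Returns None when the head has no weight column, as in the original.
        if self.weight_column_name is None:
            return None
        return features[self.weight_column_name]

    def training_loss(self, logits, target, features):
        # Multi-head case: select this head's target by name.
        target = target[self.name] if isinstance(target, dict) else target
        losses = squared_error(logits, target)  # l_i
        weights = self.get_weight_tensor(features)
        batch_size = len(losses)  # B
        if weights is None:
            return sum(losses) / batch_size  # plain mean over the batch
        # sum_i w_i * l_i / B: normalized by the batch size, not by sum(w).
        return sum(w * l for w, l in zip(weights, losses)) / batch_size


logits = [1.0, 2.0, 3.0, 4.0]
targets = {"price": [0.0, 2.0, 1.0, 4.0]}  # dict target, as in the multi-head case
features = {"w": [2.0, 1.0, 0.5, 1.0]}

print(ToyHead("price", weight_column_name="w").training_loss(logits, targets, features))  # 1.0
print(ToyHead("price").training_loss(logits, targets, features))  # 1.25
```

Here the per-example losses are [1, 0, 4, 0], so the weighted sum is 2·1 + 0.5·4 = 4 and dividing by B = 4 gives 1.0. Normalizing by the total weight instead (4 / 4.5 ≈ 0.89) would give a different number; dividing by B is what the docstring formula specifies.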