head.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def _weighted_average_loss_metric_spec(loss_fn, predictoin_key,
                                       label_key, weight_key):
  def _streaming_weighted_average_loss(predictions, labels, weights=None):
    loss_unweighted = loss_fn(predictions, labels)
    if weights is not None:
      weights = math_ops.to_float(weights)
    _, weighted_average_loss = _loss(loss_unweighted,
                                     weights,
                                     name="eval_loss")
    return metrics_lib.streaming_mean(weighted_average_loss)
  return metric_spec.MetricSpec(_streaming_weighted_average_loss,
                                predictoin_key, label_key, weight_key)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号