def _weighted_average_loss_metric_spec(loss_fn, predictoin_key,
                                       label_key, weight_key):
    """Build a MetricSpec that streams the (optionally weighted) average loss.

    The returned spec wraps a metric function which:
      1. evaluates ``loss_fn(predictions, labels)`` to get unreduced losses,
      2. casts ``weights`` to float when they are supplied,
      3. reduces the losses with ``_loss(..., name="eval_loss")`` and keeps the
         second element of its result (the weighted average loss),
      4. feeds that value into ``metrics_lib.streaming_mean``.

    Args:
      loss_fn: Callable ``(predictions, labels) -> unreduced loss``.
      predictoin_key: Key forwarded to ``MetricSpec`` as its prediction key
        (presumably selects the prediction tensor — confirm against
        MetricSpec). The misspelling is kept so that keyword callers
        keep working.
      label_key: Key forwarded to ``MetricSpec`` as its label key.
      weight_key: Key forwarded to ``MetricSpec`` as its weight key.

    Returns:
      A ``metric_spec.MetricSpec`` built from the streaming metric function
      and the three keys above.
    """

    # NOTE(review): the function name and its parameter names are kept
    # verbatim; MetricSpec may pass arguments by keyword — confirm before
    # renaming.
    def _streaming_weighted_average_loss(predictions, labels, weights=None):
        per_example_loss = loss_fn(predictions, labels)
        # Only cast when weights are present; None is forwarded untouched.
        float_weights = (
            None if weights is None else math_ops.to_float(weights))
        # _loss returns a pair; only the second (averaged) value is used.
        loss_pair = _loss(per_example_loss, float_weights, name="eval_loss")
        averaged_loss = loss_pair[1]
        return metrics_lib.streaming_mean(averaged_loss)

    return metric_spec.MetricSpec(
        _streaming_weighted_average_loss,
        predictoin_key,
        label_key,
        weight_key,
    )
评论列表
文章目录