def _wrap_metric(metric):
    """Wrap `metric` so targets are cast to the predictions' dtype.

    Predictions and targets may arrive with different dtypes. The returned
    wrapper casts `targets` to `preds.dtype` before delegating to `metric`.

    Args:
        metric: Callable invoked as `metric(preds, targets)`. When
            `_get_metric_args(metric)` contains "weights", it is instead
            invoked as `metric(preds, targets, weights=...)`.

    Returns:
        For metrics that accept "weights", a callable
        `(preds, targets, weights=None)`. Otherwise a callable
        `(preds, targets)`. Either one forwards to `metric`.
    """

    def wrapped(preds, targets):
        targets = math_ops.cast(targets, preds.dtype)
        return metric(preds, targets)

    def wrapped_weights(preds, targets, weights=None):
        targets = math_ops.cast(targets, preds.dtype)
        if weights is not None:
            # Cast weights via to_float and flatten them to rank 1.
            weights = array_ops.reshape(
                math_ops.to_float(weights), shape=(-1,))
        # Pass by keyword: this wrapper is chosen because the metric has a
        # parameter *named* "weights", which need not be its third positional
        # parameter (e.g. a functools.partial with pre-bound positionals).
        return metric(preds, targets, weights=weights)

    # NOTE(review): presumably `_get_metric_args` returns the metric's
    # parameter names — confirm against its definition.
    return wrapped_weights if "weights" in _get_metric_args(metric) else wrapped
评论列表
文章目录