linear.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def _wrap_metric(metric):
  """Wraps metrics for mismatched prediction/target types."""
  def wrapped(preds, targets):
    targets = math_ops.cast(targets, preds.dtype)
    return metric(preds, targets)

  def wrapped_weights(preds, targets, weights=None):
    targets = math_ops.cast(targets, preds.dtype)
    if weights is not None:
      weights = array_ops.reshape(math_ops.to_float(weights), shape=(-1,))
    return metric(preds, targets, weights)

  return wrapped_weights if "weights" in _get_metric_args(metric) else wrapped
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号