loss.py source code

python

Project: tensorflow-litterbox  Author: rwightman
import tensorflow as tf


def _compute_huber(predictions, labels, delta=1.0):
    # Element-wise Huber loss; the result has the same shape as the inputs.
    predictions = tf.cast(predictions, tf.float32)
    labels = tf.cast(labels, tf.float32)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    delta = tf.cast(delta, tf.float32)

    diff = predictions - labels
    diff_abs = tf.abs(diff)
    delta_fact = 0.5 * tf.square(delta)
    # Quadratic branch where |diff| < delta, linear branch elsewhere.
    condition = tf.less(diff_abs, delta)
    left_opt = 0.5 * tf.square(diff)
    right_opt = delta * diff_abs - delta_fact
    losses_val = tf.where(condition, left_opt, right_opt)
    return losses_val
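For reference, this is the piecewise rule that `_compute_huber` evaluates element-wise, writing $r$ for `predictions - labels` and $\delta$ for `delta`:

```latex
L_\delta(r) =
\begin{cases}
  \tfrac{1}{2} r^2, & |r| < \delta \\
  \delta |r| - \tfrac{1}{2}\delta^2, & |r| \ge \delta
\end{cases}
```

At $|r| = \delta$ both branches equal $\tfrac{1}{2}\delta^2$ and both slopes equal $\delta \operatorname{sign}(r)$, so the loss and its first derivative are continuous. For that reason the strict `tf.less` comparison makes no difference at the boundary.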


# Returns non-reduced tensor of unweighted losses with batch dimension matching inputs
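The comment above describes losses returned in non-reduced, unweighted form, one value per input element. `_compute_huber` behaves the same way. The following is a minimal pure-Python sketch of that behavior. It assumes plain lists in place of TensorFlow tensors, and `huber_elementwise` and `mean_loss` are hypothetical names, not part of tensorflow-litterbox. The reduction step is also only an illustration of what a caller might do afterward.

```python
from typing import List


def huber_elementwise(predictions: List[float], labels: List[float],
                      delta: float = 1.0) -> List[float]:
    """Hypothetical helper: per-element Huber loss, one output per input."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    losses = []
    for p, y in zip(predictions, labels):
        diff_abs = abs(p - y)
        if diff_abs < delta:
            # Quadratic region: 0.5 * r^2
            losses.append(0.5 * diff_abs * diff_abs)
        else:
            # Linear region: delta * |r| - 0.5 * delta^2
            losses.append(delta * diff_abs - 0.5 * delta * delta)
    return losses


def mean_loss(losses: List[float]) -> float:
    """Hypothetical reduction step: average the non-reduced losses."""
    return sum(losses) / len(losses)


per_example = huber_elementwise([0.5, 3.0, -2.0], [0.0, 0.0, 0.0], delta=1.0)
print(per_example)             # [0.125, 2.5, 1.5]
print(mean_loss(per_example))  # 1.375
```

Keeping the per-element values unreduced lets the caller apply per-example weights or masks before averaging, which a loss that returned a scalar directly would not allow.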