numerics.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:vinci 作者: Phylliade 项目源码 文件源码
def huber_loss(y_true, y_pred, clip_value):
    # Huber loss, see https://en.wikipedia.org/wiki/Huber_loss and
    # https://medium.com/@karpathy/yes-you-should-understand-backprop-e2f06eab496b
    # for details.
    assert clip_value > 0.

    x = y_true - y_pred
    if np.isinf(clip_value):
        # Spacial case for infinity since Tensorflow does have problems
        # if we compare `K.abs(x) < np.inf`.
        return .5 * tf.square(x)

    condition = tf.abs(x) < clip_value
    squared_loss = .5 * tf.square(x)
    linear_loss = clip_value * (tf.abs(x) - .5 * clip_value)

    return tf.where(condition, squared_loss, linear_loss)  # condition, true, false
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号