lossFunction.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:dwt 作者: min2209 项目源码 文件源码
def angularErrorTotal(pred, gt, weight, ss, outputChannels=2):
    with tf.name_scope("angular_error"):
        pred = tf.reshape(pred, (-1, outputChannels))
        gt = tf.to_float(tf.reshape(gt, (-1, outputChannels)))
        weight = tf.to_float(tf.reshape(weight, (-1, 1)))
        ss = tf.to_float(tf.reshape(ss, (-1, 1)))

        pred = tf.nn.l2_normalize(pred, 1) * 0.999999
        gt = tf.nn.l2_normalize(gt, 1) * 0.999999

        errorAngles = tf.acos(tf.reduce_sum(pred * gt, reduction_indices=[1], keep_dims=True))

        lossAngleTotal = tf.reduce_sum((tf.abs(errorAngles*errorAngles))*ss*weight)

        return lossAngleTotal
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号