lossFunction.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:dwt 作者: min2209 项目源码 文件源码
def exceedingAngleThreshold(pred, gt, ss, threshold, outputChannels=2):
    with tf.name_scope("angular_error"):
        pred = tf.reshape(pred, (-1, outputChannels))
        gt = tf.to_float(tf.reshape(gt, (-1, outputChannels)))
        ss = tf.to_float(tf.reshape(ss, (-1, 1)))

        pred = tf.nn.l2_normalize(pred, 1) * 0.999999
        gt = tf.nn.l2_normalize(gt, 1) * 0.999999

        errorAngles = tf.acos(tf.reduce_sum(pred * gt, reduction_indices=[1], keep_dims=True)) * ss

        exceedCount = tf.reduce_sum(tf.to_float(tf.less(threshold/180*3.14159, errorAngles)))

        return exceedCount
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号