def exceedingAngleThreshold(pred, gt, ss, threshold, outputChannels=2):
    """Count vectors whose weighted angle to the ground truth exceeds a threshold.

    ``pred`` and ``gt`` are flattened to rows of ``outputChannels`` components
    and L2-normalised row-wise. The angle (in radians) between each pair of
    rows is multiplied by the matching entry of ``ss`` and compared against
    ``threshold``.

    Args:
        pred: Predicted vectors; reshaped to ``(-1, outputChannels)``. Its
            dtype must match float32 (``gt`` is cast to float32).
        gt: Ground-truth vectors; reshaped to ``(-1, outputChannels)`` and
            cast to float32.
        ss: Per-row weight/mask; flattened to one value per row and cast to
            float32. Rows where it is 0 get a weighted angle of 0, so they are
            never counted when ``threshold`` is non-negative.
        threshold: Angle threshold in degrees (Python number or scalar tensor).
        outputChannels: Number of components per vector.

    Returns:
        A scalar float32 tensor holding the number of rows whose weighted
        angle is strictly greater than ``threshold`` (converted to radians).
    """
    import math

    with tf.name_scope("angular_error"):
        pred = tf.reshape(pred, (-1, outputChannels))
        gt = tf.cast(tf.reshape(gt, (-1, outputChannels)), tf.float32)
        # One weight per row, rank 1, to match the rank-1 angle vector below.
        ss = tf.cast(tf.reshape(ss, (-1,)), tf.float32)
        pred = tf.nn.l2_normalize(pred, 1)
        gt = tf.nn.l2_normalize(gt, 1)
        # Rounding can push the dot product of two unit vectors slightly
        # outside [-1, 1], where acos returns NaN. Clamp the cosine rather
        # than shrinking the vectors, which would bias every angle upwards.
        cosines = tf.clip_by_value(tf.reduce_sum(pred * gt, axis=1), -1.0, 1.0)
        errorAngles = tf.acos(cosines) * ss
        # Multiply before dividing so an integer threshold is never truncated
        # by integer division (Python 2), and use an exact value of pi.
        thresholdRad = threshold * math.pi / 180.0
        exceedCount = tf.reduce_sum(
            tf.cast(tf.less(thresholdRad, errorAngles), tf.float32))
    return exceedCount
评论列表
文章目录