def angularErrorTotal(pred, gt, weight, ss, outputChannels=2):
    """Return the summed, weighted squared angular error between vectors.

    ``pred`` and ``gt`` are flattened into rows of ``outputChannels``
    components. ``weight`` and ``ss`` are flattened into a single column, so
    they must supply exactly one value per row. Each row of ``pred`` and
    ``gt`` is L2-normalised, and the angle between the two rows is
    ``acos`` of their dot product (in radians).

    Args:
        pred: Predicted vectors, reshaped to ``(-1, outputChannels)``. It is
            not cast, so it presumably is already float32 to match the cast
            ``gt``. NOTE(review): confirm against the caller.
        gt: Ground-truth vectors, reshaped to ``(-1, outputChannels)`` and
            cast to float32.
        weight: Per-row weights, reshaped to ``(-1, 1)`` and cast to float32.
        ss: Per-row multiplier, reshaped to ``(-1, 1)`` and cast to float32.
            NOTE(review): likely a 0/1 mask; confirm against the caller.
        outputChannels: Number of components per vector.

    Returns:
        A scalar tensor equal to ``sum(angle**2 * ss * weight)`` over all rows.
    """
    with tf.name_scope("angular_error"):
        pred = tf.reshape(pred, (-1, outputChannels))
        # tf.cast(x, tf.float32) is what the deprecated tf.to_float did;
        # it is also available in TF2, where tf.to_float was removed.
        gt = tf.cast(tf.reshape(gt, (-1, outputChannels)), tf.float32)
        weight = tf.cast(tf.reshape(weight, (-1, 1)), tf.float32)
        ss = tf.cast(tf.reshape(ss, (-1, 1)), tf.float32)

        # Shrink the unit vectors slightly so their dot product stays strictly
        # inside (-1, 1). acos's derivative, -1/sqrt(1 - x^2), is infinite at
        # +/-1 and would produce inf/NaN gradients.
        pred = tf.nn.l2_normalize(pred, 1) * 0.999999
        gt = tf.nn.l2_normalize(gt, 1) * 0.999999

        # keepdims=True keeps the result (N, 1). An (N,) result would
        # broadcast against the (N, 1) ss/weight into an (N, N) matrix.
        errorAngles = tf.acos(tf.reduce_sum(pred * gt, axis=[1], keepdims=True))

        # The square is already non-negative, so the original tf.abs was a
        # no-op.
        lossAngleTotal = tf.reduce_sum(tf.square(errorAngles) * ss * weight)
        return lossAngleTotal
评论列表
文章目录