def log_quaternion_loss_batch(predictions, labels, name='log_quaternion_batch_loss'):
    """Compute a per-example log-space error between two batches of quaternions.

    For every row the result is ``log(1e-4 + 1 - |dot(prediction, label)|)``,
    so the value decreases as the absolute dot product of the two quaternions
    approaches 1.

    Args:
        predictions: A Tensor of size [batch_size, 4].
        labels: A Tensor of size [batch_size, 4].
        name: Name scope under which the loss ops are created.

    Returns:
        A Tensor of size [batch_size], denoting the error between the
        quaternions.
    """

    def _unit_norm_check(quaternions, message):
        # Passes only when every row's sum of squares is within 1e-4 of 1.
        squared_norms = tf.reduce_sum(tf.square(quaternions), [1])
        deviation = tf.abs(squared_norms - 1)
        all_unit = tf.reduce_all(tf.less(deviation, 1e-4))
        return tf.Assert(all_unit, [message])

    # Predictions are checked first, then labels; both checks are built before
    # (and therefore outside of) the name scope opened below.
    norm_checks = [
        _unit_norm_check(
            predictions,
            'The l2 norm of each prediction quaternion vector should be 1.'),
        _unit_norm_check(
            labels,
            'The l2 norm of each label quaternion vector should be 1.'),
    ]

    with tf.name_scope(name):
        # Only the element-wise product is gated on the checks; every later op
        # consumes it, so the checks still run before the loss is produced.
        with tf.control_dependencies(norm_checks):
            elementwise = tf.multiply(predictions, labels)
        dots = tf.reduce_sum(elementwise, [1])
        # The 1e-4 offset keeps the log argument positive when |dot| is 1.
        log_error = tf.log(1e-4 + 1 - tf.abs(dots))
    return log_error
# Comment list
# Table of contents