losses.py source code

python

Project: tefla  Author: openAGI
import tensorflow as tf


def log_quaternion_loss_batch(predictions, labels, name='log_quaternion_batch_loss'):
    """A helper function to compute the error between quaternions.

    Args:
      predictions: A Tensor of size [batch_size, 4] of unit quaternions.
      labels: A Tensor of size [batch_size, 4] of unit quaternions.
      name: Name scope for the loss ops.

    Returns:
      A Tensor of size [batch_size], denoting the error between the quaternions.
    """
    # Check that each quaternion's squared L2 norm is within 1e-4 of 1.
    assertions = []
    assertions.append(
        tf.Assert(tf.reduce_all(tf.less(tf.abs(tf.reduce_sum(tf.square(predictions), [1]) - 1), 1e-4)),
                  ['The l2 norm of each prediction quaternion vector should be 1.']))
    assertions.append(
        tf.Assert(tf.reduce_all(tf.less(tf.abs(tf.reduce_sum(tf.square(labels), [1]) - 1), 1e-4)),
                  ['The l2 norm of each label quaternion vector should be 1.']))
    with tf.name_scope(name):
        # Run the norm checks before the product is computed.
        with tf.control_dependencies(assertions):
            product = tf.multiply(predictions, labels)
        # Per-row dot product <prediction, label>.
        internal_dot_products = tf.reduce_sum(product, [1])
        # abs() because q and -q represent the same rotation; the 1e-4 offset keeps
        # the log finite when the quaternions coincide. tf.math.log replaces tf.log,
        # which was removed in TensorFlow 2.x.
        logcost = tf.math.log(1e-4 + 1 - tf.abs(internal_dot_products))
    return logcost
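For intuition, the per-pair arithmetic can be reproduced without TensorFlow. The sketch below is a hypothetical plain-Python re-implementation (`log_quaternion_cost` and `rotation_z` are illustrative names, not part of tefla), assuming quaternions are stored as `(w, x, y, z)` tuples. Because the cost uses the absolute dot product, a quaternion and its negation, which encode the same rotation, give the same cost. For unit quaternions that absolute dot product equals cos(θ/2), where θ is the angle of the relative rotation, so the cost rises from log(1e-4) ≈ -9.21 at θ = 0 toward log(1 + 1e-4) ≈ 0 at θ = 180°.

```python
import math

# Hypothetical helpers, for illustration only: a plain-Python version of the
# per-pair cost computed by log_quaternion_loss_batch.
EPS = 1e-4  # same constant as in the TensorFlow code


def log_quaternion_cost(pred, label):
    """Return log(EPS + 1 - |<pred, label>|) for two unit quaternions (4-tuples)."""
    for q in (pred, label):
        # Mirror the tf.Assert checks: sum of squared components within EPS of 1.
        if abs(sum(c * c for c in q) - 1) >= EPS:
            raise ValueError("quaternion must have unit L2 norm")
    dot = sum(p * l for p, l in zip(pred, label))
    return math.log(EPS + 1 - abs(dot))


def rotation_z(angle):
    """Unit quaternion (w, x, y, z) for a rotation of `angle` radians about the z axis."""
    return (math.cos(angle / 2), 0.0, 0.0, math.sin(angle / 2))


identity = (1.0, 0.0, 0.0, 0.0)
negated = tuple(-c for c in identity)  # same rotation as identity

print(log_quaternion_cost(identity, identity))                  # identical rotations
print(log_quaternion_cost(identity, negated))                   # q vs -q
print(log_quaternion_cost(identity, rotation_z(math.pi / 2)))   # 90 degree rotation
print(log_quaternion_cost(identity, rotation_z(math.pi)))       # 180 degree rotation
```

Identical and sign-flipped quaternions give the same, most negative cost, and the cost increases monotonically as the relative rotation angle grows toward 180°.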