xqa.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:jack 作者: uclmr 项目源码 文件源码
def xqa_crossentropy_loss(start_scores, end_scores, answer_span, answer2support, support2question, use_sum=True):
    """Very common XQA loss function."""
    num_questions = tf.reduce_max(support2question) + 1

    start, end = answer_span[:, 0], answer_span[:, 1]

    start_probs = segment_softmax(start_scores, support2question)
    start_probs = tf.gather_nd(start_probs, tf.stack([answer2support, start], 1))

    # only start probs are normalized on multi-paragraph, end probs conditioned on start only on per support level
    num_answers = tf.shape(answer_span)[0]
    is_aligned = tf.equal(tf.shape(end_scores)[0], num_answers)
    end_probs = tf.cond(
        is_aligned,
        lambda: tf.gather_nd(tf.nn.softmax(end_scores), tf.stack([tf.range(num_answers, dtype=tf.int32), end], 1)),
        lambda: tf.gather_nd(segment_softmax(end_scores, support2question), tf.stack([answer2support, end], 1))
    )

    answer2question = tf.gather(support2question, answer2support)
    # compute losses individually
    if use_sum:
        span_probs = tf.unsorted_segment_sum(
            start_probs, answer2question, num_questions) * tf.unsorted_segment_sum(
            end_probs, answer2question, num_questions)
    else:
        span_probs = tf.unsorted_segment_max(
            start_probs, answer2question, num_questions) * tf.unsorted_segment_max(
            end_probs, answer2question, num_questions)

    return -tf.reduce_mean(tf.log(tf.maximum(1e-6, span_probs + 1e-6)))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号