answer_layer.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:jack 作者: uclmr 项目源码 文件源码
def compute_spans(start_scores, end_scores, answer2support, is_eval, support2question,
                  beam_size=1, max_span_size=10000, correct_start=None):
    """Select answer spans from start/end pointer scores.

    Returns ``tf.cond(is_eval, <eval branch>, <training branch>)``.

    Training branch: the score rows chosen by ``answer2support`` are gathered, and
    the start/end pointers are the per-row argmax over axis 1. When
    ``correct_start`` is given, the left mask produced by
    ``misc.mask_for_lengths(..., mask_right=False)`` is added to the end scores
    first, so only ends after the known start are considered.

    Eval branch: ``_get_top_k`` is called twice. The first call is anchored on the
    start scores. The second call swaps start/end scores and negates
    ``max_span_size``. Both candidate sets are concatenated along axis 1, and the
    ``beam_size`` best by span score are kept. Those are then flattened to
    ``rows * beam_size`` entries.

    Args:
        start_scores: start-pointer scores; ``tf.shape(start_scores)[1]`` is used
            as the support length.
        end_scores: end-pointer scores (presumably same layout as
            ``start_scores`` -- TODO confirm).
        answer2support: per answer, the index of the support row to use.
        is_eval: boolean scalar tensor selecting the eval branch.
        support2question: per support row, the index of its question.
        beam_size: number of spans kept per row in the eval branch.
        max_span_size: span-size limit forwarded to ``_get_top_k``.
        correct_start: optional known start positions (training branch only).

    Returns:
        ``(start_scores, end_scores, doc_idx, start_pointer, end_pointer)``.
        The score tensors are returned unchanged. ``doc_idx`` is the chosen
        support's position within its question's supports.
    """
    max_support_length = tf.shape(start_scores)[1]

    # Offset of each support inside its question's group of supports.
    # NOTE(review): this assumes support2question is grouped and its ids appear
    # in order 0, 1, ...  unique_with_counts reports counts in first-occurrence
    # order, but the exclusive-cumsum offsets are indexed by the question id.
    _, _, supports_per_question = tf.unique_with_counts(support2question)
    first_support_of_question = tf.cumsum(supports_per_question, exclusive=True)
    num_supports = tf.shape(support2question)[0]
    doc_idx_for_support = (tf.range(num_supports)
                           - tf.gather(first_support_of_question, support2question))

    def _training_spans():
        answer_starts = tf.gather(start_scores, answer2support)
        answer_ends = tf.gather(end_scores, answer2support)

        if correct_start is not None:
            # Knowing the correct start, only ends after it are considered.
            answer_ends = answer_ends + misc.mask_for_lengths(
                tf.cast(correct_start, tf.int32), max_support_length, mask_right=False)

        return (start_scores,
                end_scores,
                tf.gather(doc_idx_for_support, answer2support),
                tf.argmax(answer_starts, axis=1, output_type=tf.int32),
                tf.argmax(answer_ends, axis=1, output_type=tf.int32))

    def _eval_spans():
        # Candidate spans anchored on the top-k starts ...
        fwd_doc, fwd_start, fwd_end, fwd_score = _get_top_k(
            start_scores, end_scores, beam_size, max_span_size, support2question)
        # ... and on the top-k ends. The score tensors are passed in swapped order,
        # so the helper's second output is the end pointer and its third is the start.
        bwd_doc, bwd_end, bwd_start, bwd_score = _get_top_k(
            end_scores, start_scores, beam_size, -max_span_size, support2question)

        def _join(first, second):
            return tf.concat([first, second], 1)

        all_docs = _join(fwd_doc, bwd_doc)
        all_starts = _join(fwd_start, bwd_start)
        all_ends = _join(fwd_end, bwd_end)
        all_scores = _join(fwd_score, bwd_score)

        _, best_columns = tf.nn.top_k(all_scores, beam_size)

        # Build (row, column) coordinates for gather_nd: each row id is repeated
        # beam_size times and paired with one of its selected columns.
        num_rows = tf.shape(all_scores)[0]
        row_ids = tf.range(num_rows, dtype=tf.int32)
        row_ids = tf.reshape(tf.tile(tf.expand_dims(row_ids, 1), [1, beam_size]), [-1, 1])
        coords = tf.concat([row_ids, tf.reshape(best_columns, [-1, 1])], 1)

        chosen_docs = tf.gather_nd(all_docs, coords)
        chosen_starts = tf.gather_nd(all_starts, coords)
        chosen_ends = tf.gather_nd(all_ends, coords)

        return (start_scores, end_scores,
                tf.gather(doc_idx_for_support, chosen_docs), chosen_starts, chosen_ends)

    return tf.cond(is_eval, _eval_spans, _training_spans)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号