def compute_spans(start_scores, end_scores, answer2support, is_eval, support2question,
                  beam_size=1, max_span_size=10000, correct_start=None):
    """Predict answer spans, switching between a training and an evaluation path.

    The path is selected at graph-run time via ``tf.cond(is_eval, ...)``.

    Args:
        start_scores: start-position score tensor. Axis 0 indexes supports (it is
            gathered with ``answer2support``) and axis 1 indexes token positions
            (``tf.shape(start_scores)[1]`` is used as the max support length).
        end_scores: end-position scores, laid out like ``start_scores``.
        answer2support: for each answer, the index of its support along axis 0
            of the score tensors. Only used by the training path.
        is_eval: boolean scalar tensor; ``True`` selects the beam-search path.
        support2question: for each support, the id of the question it belongs to.
        beam_size: number of spans kept per row on the evaluation path.
        max_span_size: forwarded to ``_get_top_k`` (negated for the end-first
            pass). NOTE(review): presumably the maximum span length -- confirm
            against ``_get_top_k``.
        correct_start: optional gold start positions, one per answer. When
            given, the training path adds
            ``misc.mask_for_lengths(correct_start, ..., mask_right=False)`` to the
            gathered end scores, intended so that only ends after the gold start
            are considered.

    Returns:
        Tuple ``(start_scores, end_scores, doc_idx, start_pointer, end_pointer)``.
        The two score tensors are passed through unchanged. ``doc_idx`` is the
        position of the selected support among its question's supports.
        Training path: one entry per answer (argmax of the gathered scores).
        Evaluation path: ``beam_size`` entries per row of the candidate scores
        returned by ``_get_top_k``, gathered with ``tf.gather_nd``.
    """
    max_support_length = tf.shape(start_scores)[1]

    # Position of each support within its question's group of supports:
    # global support index minus the index where that question's group starts.
    # NOTE(review): unique_with_counts orders counts by first appearance, but
    # they are looked up by question id below. This is only correct if
    # support2question is grouped/sorted and its ids are exactly 0..Q-1 with
    # every id present -- confirm with callers.
    _, _, num_doc_per_question = tf.unique_with_counts(support2question)
    offsets = tf.cumsum(num_doc_per_question, exclusive=True)
    doc_idx_for_support = tf.range(tf.shape(support2question)[0]) - tf.gather(offsets, support2question)

    def _train_branch():
        """Greedy per-answer prediction: independent argmax over start and end scores."""
        gathered_end_scores = tf.gather(end_scores, answer2support)
        gathered_start_scores = tf.gather(start_scores, answer2support)

        if correct_start is not None:
            # Assuming we know the correct start, we only consider ends after it.
            left_mask = misc.mask_for_lengths(tf.cast(correct_start, tf.int32), max_support_length,
                                              mask_right=False)
            gathered_end_scores = gathered_end_scores + left_mask

        predicted_start_pointer = tf.argmax(gathered_start_scores, axis=1, output_type=tf.int32)
        predicted_end_pointer = tf.argmax(gathered_end_scores, axis=1, output_type=tf.int32)

        return (start_scores, end_scores,
                tf.gather(doc_idx_for_support, answer2support),
                predicted_start_pointer, predicted_end_pointer)

    def _eval_branch():
        """Beam prediction: keep the best beam_size of 2 * beam_size candidate spans."""
        # Collect spans for the top-k starts and for the top-k ends, then keep
        # the top k of those 2k candidates.
        doc_idx1, start_pointer1, end_pointer1, span_score1 = _get_top_k(
            start_scores, end_scores, beam_size, max_span_size, support2question)
        # This pass swaps the roles of start and end scores, so its pointers are
        # unpacked in (end, start) order.
        doc_idx2, end_pointer2, start_pointer2, span_score2 = _get_top_k(
            end_scores, start_scores, beam_size, -max_span_size, support2question)

        doc_idx = tf.concat([doc_idx1, doc_idx2], 1)
        start_pointer = tf.concat([start_pointer1, start_pointer2], 1)
        end_pointer = tf.concat([end_pointer1, end_pointer2], 1)
        span_score = tf.concat([span_score1, span_score2], 1)

        # NOTE(review): no deduplication between the two candidate sets is done
        # here. If both passes can return the same span, it may fill two beam
        # slots -- confirm against _get_top_k.
        _, top_idx = tf.nn.top_k(span_score, beam_size)

        # Build explicit (row, column) index pairs so gather_nd can pick the
        # selected candidates out of each row.
        rows = tf.range(tf.shape(span_score)[0], dtype=tf.int32)
        rows = tf.reshape(tf.tile(tf.expand_dims(rows, 1), [1, beam_size]), [-1, 1])
        gather_idx = tf.concat([rows, tf.reshape(top_idx, [-1, 1])], 1)

        doc_idx = tf.gather_nd(doc_idx, gather_idx)
        start_pointer = tf.gather_nd(start_pointer, gather_idx)
        end_pointer = tf.gather_nd(end_pointer, gather_idx)

        return (start_scores, end_scores, tf.gather(doc_idx_for_support, doc_idx),
                start_pointer, end_pointer)

    return tf.cond(is_eval, _eval_branch, _train_branch)
评论列表
文章目录