def _get_top_k(scores1, scores2, k, max_span_size, support2question):
    """Pick a top-``k`` beam of first pointers, then the best second pointer for each.

    The first pointers come from ``segment_top_k(scores1, support2question, k)``.
    For every beam entry, the matching document row of ``scores2`` is masked so
    only a window of ``abs(max_span_size)`` positions next to the first pointer
    competes, and the argmax inside that window becomes the second pointer.

    Args:
        scores1: scores for the first pointer; axis 1 is treated as the support
            length (presumably ``[num_supports, support_length]`` — TODO confirm).
        scores2: scores for the second pointer, row-indexed like ``scores1``.
        k: beam size forwarded to ``segment_top_k``.
        max_span_size: window width for the second pointer. A non-negative value
            starts the window at the first pointer; a negative value shifts the
            window start to ``pointer1 + max_span_size + 1`` and uses width
            ``-max_span_size``, so the window ends at the first pointer.
            NOTE(review): exact inclusive/exclusive bounds depend on
            ``misc.mask_for_lengths`` semantics, which are not visible here.
        support2question: support-to-question mapping forwarded to
            ``segment_top_k``.

    Returns:
        ``(doc_idx, pointer1, pointer2, span_scores)``. ``doc_idx`` and
        ``pointer1`` are returned exactly as produced by ``segment_top_k``.
        ``pointer2`` is reshaped to ``[-1, k]``. ``span_scores`` is
        ``topk_scores1`` plus the *unmasked* ``scores2`` value at each chosen
        second pointer, also reshaped to ``[-1, k]``.
    """
    support_length = tf.shape(scores1)[1]
    doc_idx, pointer1, topk_scores1 = segment_top_k(scores1, support2question, k)

    # One row per beam candidate: [num_questions * beam_size].
    flat_doc_idx = tf.reshape(doc_idx, [-1])
    flat_start = tf.reshape(pointer1, [-1])

    # Each candidate's document row of scores2:
    # [num_questions * beam_size, support_length].
    candidate_scores2 = tf.gather(scores2, flat_doc_idx)

    # A negative span size searches backwards: move the window start so the
    # window ends at the first pointer, then treat it as a forward window.
    if max_span_size < 0:
        window_start = flat_start + max_span_size + 1
        window_size = -max_span_size
    else:
        window_start = flat_start
        window_size = max_span_size

    # Additive masks: one for positions before the window start, one for
    # positions from window_start + window_size onward.
    before_window = misc.mask_for_lengths(tf.cast(window_start, tf.int32),
                                          support_length, mask_right=False)
    after_window = misc.mask_for_lengths(tf.cast(window_start + window_size, tf.int32),
                                         support_length)
    masked_scores2 = candidate_scores2 + before_window + after_window

    pointer2 = tf.argmax(masked_scores2, axis=1, output_type=tf.int32)

    # Look up the chosen second-pointer score in the original (unmasked) scores2.
    gather_idx = tf.stack([flat_doc_idx, pointer2], 1)
    topk_scores2 = tf.gather_nd(scores2, gather_idx)

    pointer2_beam = tf.reshape(pointer2, [-1, k])
    span_scores = topk_scores1 + tf.reshape(topk_scores2, [-1, k])
    return doc_idx, pointer1, pointer2_beam, span_scores
评论列表
文章目录