def max_sentence_similarity(sentence_input, similarity_matrix):
    """
    Gather, for every word position, the hidden vector of its most similar word.

    For each batch element ``b`` and word position ``i`` this selects
    ``j = argmax_k similarity_matrix[b, i, k]`` and returns
    ``sentence_input[b, j, :]``.

    Parameters
    ----------
    sentence_input: Tensor
        Tensor of shape (batch_size, num_sentence_words, rnn_hidden_dim).
        Rows are gathered from axis 1 of this tensor using indices taken
        from axis 2 of ``similarity_matrix``, so those two dimensions are
        assumed to match (TODO: confirm at call sites).
    similarity_matrix: Tensor
        Tensor of shape (batch_size, num_sentence_words, num_sentence_words).

    Returns
    -------
    Tensor
        Tensor of shape (batch_size, num_sentence_words, rnn_hidden_dim) with
        the same dtype as ``sentence_input``.
    """
    # No-op for an existing Tensor; lets us read a reliable ``.dtype`` below.
    sentence_input = tf.convert_to_tensor(sentence_input)

    def single_instance(inputs):
        """Gather rows of one sentence by that sentence's argmax indices."""
        single_sentence, argmax_index = inputs
        # Shape: (num_sentence_words, rnn_hidden_dim)
        return tf.gather(single_sentence, argmax_index)

    # Shape: (batch_size, num_sentence_words) -- index of the most similar
    # word for each position. (``tf.arg_max`` is deprecated / removed in TF 2.)
    argmax_index = tf.argmax(similarity_matrix, axis=2)

    # map_fn needs an explicit output dtype because the output structure
    # differs from the (sentence, index) input tuple. Derive it from the input
    # rather than hard-coding float32 so float16/float64 inputs also work.
    # Shape: (batch_size, num_sentence_words, rnn_hidden_dim)
    return tf.map_fn(
        single_instance,
        (sentence_input, argmax_index),
        dtype=sentence_input.dtype,
    )
评论列表
文章目录