def mask_similarity_matrix(similarity_matrix, mask_a, mask_b):
    """
    Apply the two sentence masks to a similarity matrix.

    Each mask is cast to float, given an extra singleton axis so that it
    broadcasts against the similarity matrix, and multiplied in
    elementwise. Every entry is thus scaled by the matching value of
    ``mask_a`` and the matching value of ``mask_b`` (i.e. multiplied by 0
    wherever either mask is 0).

    Parameters
    ----------
    similarity_matrix: Tensor
        Tensor of shape (batch_size, num_sentence_words, num_sentence_words).

    mask_a: Tensor
        Tensor of shape (batch_size, num_sentence_words). This mask should
        correspond to the first vector (v1) used to calculate the similarity
        matrix. It is expanded at axis 1, so its values are applied along
        the last axis of the similarity matrix.

    mask_b: Tensor
        Tensor of shape (batch_size, num_sentence_words). This mask should
        correspond to the second vector (v2) used to calculate the similarity
        matrix. It is expanded at axis 2, so its values are applied along
        axis 1 of the similarity matrix.

    Returns
    -------
    Tensor
        The similarity matrix after elementwise multiplication by both
        broadcast masks.
    """
    # (batch_size, 1, num_sentence_words): repeated across axis 1.
    mask_a_broadcast = tf.expand_dims(tf.cast(mask_a, "float"), 1)
    masked_by_a = tf.multiply(similarity_matrix, mask_a_broadcast)

    # (batch_size, num_sentence_words, 1): repeated across axis 2.
    mask_b_broadcast = tf.expand_dims(tf.cast(mask_b, "float"), 2)
    return tf.multiply(masked_by_a, mask_b_broadcast)
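# Comment list (web-page navigation residue; not part of the code)
# Table of contents (web-page navigation residue; not part of the code)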