matching.py source code

python

Project: paraphrase-id-tensorflow  Author: nelson-liu
import tensorflow as tf


def mask_similarity_matrix(similarity_matrix, mask_a, mask_b):
    """
    Given the mask of the two sentences, apply the mask to the similarity
    matrix.

    Parameters
    ----------
    similarity_matrix: Tensor
        Tensor of shape (batch_size, num_sentence_words, num_sentence_words).

    mask_a: Tensor
        Tensor of shape (batch_size, num_sentence_words). This mask should
        correspond to the first vector (v1) used to calculate the similarity
        matrix.

    mask_b: Tensor
        Tensor of shape (batch_size, num_sentence_words). This mask should
        correspond to the second vector (v2) used to calculate the similarity
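        matrix.

    Returns
    -------
    Tensor
        Tensor of the same shape as similarity_matrix, with entries at
        positions where either mask is zero set to zero.
    """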
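    # mask_a becomes (batch_size, 1, num_sentence_words) and broadcasts over
    # axis 1, so it zeroes entries along the last axis.
    similarity_matrix = tf.multiply(similarity_matrix,
                                    tf.expand_dims(tf.cast(mask_a, tf.float32), 1))
    # mask_b becomes (batch_size, num_sentence_words, 1) and broadcasts over
    # axis 2, so it zeroes entries along axis 1.
    similarity_matrix = tf.multiply(similarity_matrix,
                                    tf.expand_dims(tf.cast(mask_b, tf.float32), 2))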
    return similarity_matrix
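
To make the broadcasting in the function above easier to follow, here is a minimal pure-Python sketch of the same elementwise computation, written without TensorFlow so that the indices are spelled out. The name `mask_similarity_matrix_reference` and the sample values are hypothetical and are not part of the project. The sketch assumes masks hold 1 for real tokens and 0 for padding.

```python
def mask_similarity_matrix_reference(similarity_matrix, mask_a, mask_b):
    """Hypothetical pure-Python reference for the masking step.

    Computes out[b][i][j] = sim[b][i][j] * mask_a[b][j] * mask_b[b][i],
    which mirrors expanding mask_a at axis 1 and mask_b at axis 2.
    """
    return [
        [
            [float(value) * float(mask_a[b][j]) * float(mask_b[b][i])
             for j, value in enumerate(row)]
            for i, row in enumerate(matrix)
        ]
        for b, matrix in enumerate(similarity_matrix)
    ]


# One batch element with a 3 x 3 similarity matrix (made-up values).
similarity = [[[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0],
               [7.0, 8.0, 9.0]]]
mask_a = [[1, 1, 0]]  # assumed: last position of the first input is padding
mask_b = [[1, 0, 0]]  # assumed: last two positions of the second input are padding

masked = mask_similarity_matrix_reference(similarity, mask_a, mask_b)
for row in masked[0]:
    print(row)
```

Note that, as written, `mask_a` indexes the last axis of the matrix and `mask_b` indexes the middle axis, so the caller has to build the similarity matrix with that orientation.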