embeddings.py source code


Project: sockeye  Author: awslabs
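import logging

import mxnet as mx
import numpy as np

logger = logging.getLogger(__name__)


def compute_sims(inputs: mx.nd.NDArray, normalize: bool) -> mx.nd.NDArray:
    """
    Returns a matrix with pair-wise similarity scores between inputs.
    Similarity score is the dot product of each pair of rows (cosine similarity
    when normalize=True). 'Similarity with self' is masked to a large negative value.

    :param inputs: NDArray of inputs, shape (num_items, dim).
    :param normalize: Whether to normalize each row to unit length.
    :return: NDArray of pairwise similarities, shape (num_items, num_items).
    """
    if normalize:
        logger.info("Normalizing embeddings to unit length")
        # mode='instance' divides each row (instance) by its own L2 norm
        inputs = mx.nd.L2Normalization(inputs, mode='instance')
    # (n, d) x (d, n) -> (n, n) matrix of pairwise dot products
    sims = mx.nd.dot(inputs, inputs, transpose_b=True)
    # Mask the diagonal on the host so an item is never scored as its own neighbour
    sims_np = sims.asnumpy()
    np.fill_diagonal(sims_np, -9999999.)
    # Copy back to the device the similarities were computed on
    sims = mx.nd.array(sims_np, ctx=sims.context)
    return sims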
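The function above is easier to try out without MXNet. Below is a minimal NumPy sketch of the same technique: optionally L2-normalize the rows, take the matrix of pairwise dot products, and overwrite the diagonal with a large negative value. The names compute_sims_np and nearest_neighbours are hypothetical helpers, not part of sockeye. The mask value -9999999. is copied from the original. The zero-norm guard is an assumption of this sketch and may differ slightly from MXNet's L2Normalization epsilon handling.

```python
import numpy as np


def compute_sims_np(inputs: np.ndarray, normalize: bool) -> np.ndarray:
    """Hypothetical NumPy counterpart of compute_sims: (n, d) -> (n, n)."""
    x = np.asarray(inputs, dtype=np.float64)
    if normalize:
        norms = np.linalg.norm(x, axis=1, keepdims=True)
        # Guard against division by zero for all-zero rows (assumption of this sketch)
        x = x / np.maximum(norms, 1e-12)
    sims = x @ x.T  # (n, d) @ (d, n) -> (n, n) pairwise dot products
    # Mask self-similarity so an item never ranks as its own nearest neighbour
    np.fill_diagonal(sims, -9999999.)
    return sims


def nearest_neighbours(sims: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical helper: indices of the k most similar items per row, best first."""
    # Sorting the negated scores ascending yields descending similarity
    return np.argsort(-sims, axis=1)[:, :k]


if __name__ == "__main__":
    emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])
    sims = compute_sims_np(emb, normalize=True)
    print(nearest_neighbours(sims, k=1).ravel())  # prints [1 0 1]
```

Because the diagonal holds the smallest value in every row, a top-k lookup like nearest_neighbours skips each item itself without any special-casing. That is presumably why compute_sims masks self-similarity rather than leaving it at 1.0.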