similarity.py source code

python

Project: sef  Author: passalis
import torch


def sym_distance_matrix(A, B, eps=1e-18, self_similarity=False):
    """
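    Computes the matrix of pairwise Euclidean distances between the row vectors of A and B
    :param A: the first data matrix (one vector per row)
    :param B: the second data matrix (one vector per row)
    :param eps: the minimum squared distance between two vectors; squared distances are clamped
        to this very small value before the square root to improve stability
    :param self_similarity: zeros the diagonal of the squared distances to improve stability
        (the diagonal indexing assumes A and B have the same number of rows, e.g. A is B)
    :return: a matrix of shape (A.size(0), B.size(0)) holding the pairwise distances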
    """
    # Compute the squared distances
    AA = torch.sum(A * A, 1).view(-1, 1)
    BB = torch.sum(B * B, 1).view(1, -1)
    AB = torch.mm(A, B.transpose(0, 1))
    D = AA + BB - 2 * AB

    # Zero the diagonal
    if self_similarity:
        D = D.view(-1)
        D[::B.size(0) + 1] = 0
        D = D.view(A.size(0), B.size(0))

    # Return the square root
    D = torch.sqrt(torch.clamp(D, min=eps))

    return D
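
The function above relies on the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b. As a rough sanity check, the sketch below reproduces that expansion in plain Python (no torch) and compares it with a brute-force computation. The helper names pairwise_distances_direct and pairwise_distances_expanded are hypothetical and are not part of the sef project.

```python
import math


def pairwise_distances_direct(A, B):
    """Brute-force Euclidean distances between the rows of A and B."""
    return [[math.sqrt(sum((a - b) ** 2 for a, b in zip(row_a, row_b)))
             for row_b in B] for row_a in A]


def pairwise_distances_expanded(A, B, eps=1e-18):
    """Same distances via ||a||^2 + ||b||^2 - 2 a.b, clamped at eps before the sqrt."""
    AA = [sum(x * x for x in row) for row in A]  # squared norms of the rows of A
    BB = [sum(x * x for x in row) for row in B]  # squared norms of the rows of B
    D = []
    for i, row_a in enumerate(A):
        out = []
        for j, row_b in enumerate(B):
            dot = sum(a * b for a, b in zip(row_a, row_b))
            squared = AA[i] + BB[j] - 2.0 * dot
            # Rounding can push `squared` slightly below zero, hence the clamp.
            out.append(math.sqrt(max(squared, eps)))
        D.append(out)
    return D


A = [[0.0, 0.0], [3.0, 4.0]]
B = [[0.0, 0.0], [6.0, 8.0]]
direct = pairwise_distances_direct(A, B)
expanded = pairwise_distances_expanded(A, B)
```

Note that because of the clamp, a pair of identical vectors yields sqrt(eps) rather than exactly zero in the expanded version; the same holds for the zeroed diagonal in sym_distance_matrix.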