def sym_distance_matrix(A, B, eps=1e-18, self_similarity=False):
    """
    Compute the pairwise Euclidean distances between the rows of A and B.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b so the whole
    matrix is obtained from a single matrix product; every step is a
    differentiable torch operation.

    :param A: the first data matrix, one vector per row (2-D, shape (n, d))
    :param B: the second data matrix, one vector per row (2-D, shape (m, d))
    :param eps: lower bound applied to the *squared* distances before the
        square root (so every returned distance is at least sqrt(eps)); keeps
        sqrt away from zero/negative inputs produced by round-off
    :param self_similarity: if True, set the diagonal entries D[i, i] to zero
        before clamping (intended for A and B holding the same vectors), which
        removes round-off error on the self-distances
    :return: an (n, m) tensor whose entry (i, j) is the distance between
        A[i] and B[j]
    """
    # Squared distances; may come out slightly negative due to round-off.
    AA = torch.sum(A * A, 1).view(-1, 1)  # (n, 1)
    BB = torch.sum(B * B, 1).view(1, -1)  # (1, m)
    AB = torch.mm(A, B.transpose(0, 1))   # (n, m)
    D = AA + BB - 2 * AB
    if self_similarity:
        # Zero the main diagonal. torch.eye(n, m) marks exactly the (i, i)
        # entries even when n != m (a flat strided write does not), and
        # masked_fill is out-of-place, so no in-place view tricks are needed.
        diag_mask = torch.eye(D.size(0), D.size(1), dtype=torch.bool, device=D.device)
        D = D.masked_fill(diag_mask, 0)
    # Clamp before the square root: negative round-off would give NaN, and
    # sqrt's derivative is unbounded at zero.
    return torch.sqrt(torch.clamp(D, min=eps))
评论列表
文章目录