def pdist(x: T.FloatTensor, y: T.FloatTensor) -> T.FloatTensor:
    """
    Compute the pairwise distance matrix between the rows of x and y.

    Uses the expansion ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>,
    so only one matrix product and two row-norm vectors are computed; the
    full difference tensor between every pair of rows is never built.
    (This is the Euclidean distance provided `norm` is the L2 norm.)

    Args:
        x (tensor (num_samples_1, num_units))
        y (tensor (num_samples_2, num_units))

    Returns:
        tensor (num_samples_1, num_samples_2): entry [i, j] is the distance
        between row i of x and row j of y.

    """
    # inner[i, j] = <x_i, y_j>, shape (num_samples_1, num_samples_2)
    inner = dot(x, transpose(y))

    # Squared row norms. Assumed to be 1-D after reducing over axis=1;
    # the unsqueeze calls below rely on that.
    x_sq = norm(x, axis=1) ** 2
    y_sq = norm(y, axis=1) ** 2

    # Broadcast x_sq as a column and y_sq as a row against -2 * inner
    # to get the (num_samples_1, num_samples_2) squared-distance matrix.
    squared = add(unsqueeze(y_sq, axis=0),
                  add(unsqueeze(x_sq, axis=1), -2 * inner))

    # Floating-point cancellation can leave tiny negative values where the
    # true squared distance is ~0; clamp them so the sqrt does not yield NaN.
    return torch.sqrt(clip(squared, a_min=0))
评论列表
文章目录