def test_distance_calculations():
    """
    Test that the fast distance-matrix calculation agrees with a reference.

    Builds two random matrices with the same number of columns. It checks
    that ``fast_distance_matrix(A, B)`` has one row per row of ``A`` and one
    column per row of ``B``. It also checks that its values match
    ``pairwise_distances(A, B)`` up to a small summed squared error.
    """
    np.random.seed(1)  # fixed seed so the test is reproducible
    # Create random data vectors: 10 and 5 row vectors, each of dimension 23
    A = np.random.randn(10, 23)
    B = np.random.randn(5, 23)

    sef_dists = fast_distance_matrix(A, B)
    assert sef_dists.shape[0] == 10
    assert sef_dists.shape[1] == 5

    dists = pairwise_distances(A, B)
    # Square the differences so that errors of opposite sign cannot cancel
    # out and a large negative discrepancy cannot slip under the threshold.
    assert np.sum((sef_dists - dists) ** 2) < 1e-3
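# NOTE(review): the two lines that followed here were stray web-page labels
# ("Comment list", "Table of contents") left over from scraping; not code.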