test.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:pdist 作者: oliviaguest 项目源码 文件源码
def test_mean_of_distances(self):
        """Test the mean of distances calculation (and the sum)."""
        X = np.array([[0.3, 0.4],
                      [0.1, 4.0],
                      [2.0, 1.0],
                      [0.0, 0.5]])
        counts = np.array([3, 2, 1, 2])
        scipy_X = []
        for c, count in enumerate(counts):
            for i in range(count):
                scipy_X.append(X[c])

        # SciPy:
        Y = pdist(scipy_X, metric=cdist)
        scipy_N = np.sum(counts)
        N_unique_pairs = scipy_N * (scipy_N - 1.0) / 2.0
        scipy_mean = Y.mean()
        self.assertTrue(Y.shape[0] == N_unique_pairs)
        self.assertTrue(scipy_mean == (np.sum(Y) / N_unique_pairs))

        # C & Cython:
        c_mean = c_mean_dist(X, counts)
        self.assertTrue(np.isclose(c_mean, scipy_mean))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号