lmnn.py source code

python

Project: pylmnn  Author: johny-c
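# Imports inferred from the calls in the function body (the excerpt omits them).
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.utils import gen_batches


def _find_impostors_batch(x1, x2, t1, t2, return_dist=False, batch_size=500):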
        """Find impostor pairs in chunks to limit memory usage.

        Parameters
        ----------
        x1 : array_like
            An array of transformed data samples with shape (n_samples, n_features).
        x2 : array_like
            An array of transformed data samples with shape (m_samples, n_features),
            where m_samples < n_samples.
        t1 : array_like
            An array of distances to the margins with shape (n_samples,), one per row of x1.
            Values are compared against squared Euclidean distances.
        t2 : array_like
            An array of distances to the margins with shape (m_samples,), one per row of x2.
            Values are compared against squared Euclidean distances.
        return_dist : bool (Default value = False)
            Whether to also return the squared distances of the impostor pairs.
        batch_size : int (Default value = 500)
            The number of rows of x1 processed per chunk when computing distances to x2.

        Returns
        -------
        tuple: (list, list, [list])

            imp1 : list
                Row indices into x1, with length n_impostors.
            imp2 : list
                Row indices into x2 of the samples that violate a margin, with length n_impostors.
            dist : list, optional
                Squared pairwise distances of (imp1, imp2), with length n_impostors.
                Returned only if return_dist is True.

        """

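        n = len(t1)
        imp1, imp2, dist = [], [], []
        # Process x1 in row chunks so that only one block of distances is held at a time.
        for chunk in gen_batches(n, batch_size):
            dist_out_in = euclidean_distances(x1[chunk], x2, squared=True)
            # Row-wise test: each row is compared with the threshold of its x1 sample.
            i1, j1 = np.where(dist_out_in < t1[chunk, None])
            # Column-wise test: each column is compared with the threshold of its x2 sample.
            i2, j2 = np.where(dist_out_in < t2[None, :])
            if len(i1):
                # Row indices are relative to the chunk; shift them to index into x1.
                imp1.extend(i1 + chunk.start)
                imp2.extend(j1)
                if return_dist:
                    dist.extend(dist_out_in[i1, j1])
            if len(i2):
                imp1.extend(i2 + chunk.start)
                imp2.extend(j2)
                if return_dist:
                    dist.extend(dist_out_in[i2, j2])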

        if return_dist:
            return imp1, imp2, dist
        else:
            return imp1, imp2
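
The two threshold tests and the `chunk.start` offset can be checked on a tiny input. The following is a minimal sketch using NumPy only, not the pylmnn code itself: `batches`, `sq_dists`, and `find_pairs` are hypothetical helper names introduced for illustration, with `batches` standing in for `gen_batches` and `sq_dists` for `euclidean_distances(..., squared=True)`. The data values are made up.

```python
import numpy as np


def batches(n, batch_size):
    """Yield slices covering range(n); a simple stand-in for gen_batches."""
    for start in range(0, n, batch_size):
        yield slice(start, min(start + batch_size, n))


def sq_dists(a, b):
    """Squared Euclidean distances between the rows of a and the rows of b."""
    diff = a[:, None, :] - b[None, :, :]
    return (diff ** 2).sum(axis=-1)


def find_pairs(x1, x2, t1, t2, batch_size):
    """Collect (x1 index, x2 index) pairs that pass either threshold test."""
    pairs = []
    for chunk in batches(len(t1), batch_size):
        d = sq_dists(x1[chunk], x2)
        # Each row is compared with its own threshold from t1.
        i1, j1 = np.where(d < t1[chunk, None])
        # Each column is compared with its own threshold from t2.
        i2, j2 = np.where(d < t2[None, :])
        # Row indices are local to the chunk, so shift them by chunk.start.
        pairs.extend(zip((i1 + chunk.start).tolist(), j1.tolist()))
        pairs.extend(zip((i2 + chunk.start).tolist(), j2.tolist()))
    return pairs


# Made-up one-dimensional samples and thresholds.
x1 = np.array([[0.0], [1.0], [5.0]])
x2 = np.array([[0.5], [4.0]])
t1 = np.array([1.0, 0.1, 2.0])
t2 = np.array([0.3, 0.5])

chunked = find_pairs(x1, x2, t1, t2, batch_size=2)  # two chunks
whole = find_pairs(x1, x2, t1, t2, batch_size=3)    # a single chunk

# The chunk size changes the order of the pairs but not their content.
print(sorted(chunked) == sorted(whole))
```

In this example the pair (0, 0) passes both tests and is therefore listed twice, which mirrors the two separate `extend` calls in the loop above. A caller that needs unique pairs would have to deduplicate the result.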