pairwise.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目：dask-ml　作者：dask　项目源码　文件源码
def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
                                  batch_size=None,
                                  metric_kwargs=None):
    """Blockwise ``pairwise_distances_argmin_min`` over the rows of ``X``.

    For every row of ``X``, find the index of the closest row of ``Y``
    (under ``metric``) and the corresponding distance. The work is split
    along the row-chunks of ``X``. Each row-block is passed, together with
    ``Y``, to ``metrics.pairwise_distances_argmin_min`` as a delayed task.
    ``metrics`` is presumably ``sklearn.metrics``; confirm against the
    module imports. The per-block results are concatenated back into dask
    arrays.

    Parameters
    ----------
    X : dask array (2-D)
        Query samples. Any chunking along axis 1 is merged first, so that
        every task receives complete rows.
    Y : array-like
        Reference points, passed to every per-block task.
    axis : int, default 1
        Only ``axis=1`` (for each row of ``X``, reduce over ``Y``) can be
        assembled from row-blocks of ``X``; other values are rejected.
    metric : str or callable, default "euclidean"
        Forwarded to the per-block call.
    batch_size : int, optional
        Forwarded to the per-block call. Defaults to the largest row-chunk
        of ``X``.
    metric_kwargs : dict, optional
        Forwarded to the per-block call.

    Returns
    -------
    argmins : dask array of int64, shape (X.shape[0],)
    mins : dask array of float64, shape (X.shape[0],)

    Raises
    ------
    NotImplementedError
        If ``axis`` is not 1.
    """
    if axis != 1:
        # Each task only sees one row-block of X, so a reduction over X
        # (axis=0) cannot be built from the per-block results. Previously
        # this argument was silently ignored.
        raise NotImplementedError(
            "pairwise_distances_argmin_min only supports axis=1, "
            "got axis={!r}".format(axis))
    if X.ndim == 2 and len(X.chunks[1]) > 1:
        # Distances need whole feature vectors. Without this, a
        # column-chunked X would be flattened into partial-row blocks
        # misaligned with X.chunks[0].
        X = X.rechunk({1: -1})
    if batch_size is None:
        batch_size = max(X.chunks[0])

    # One delayed task per row-block of X; nout=2 lets the (argmin, min)
    # tuple be unpacked into two separate delayed objects.
    XD = X.to_delayed().flatten().tolist()
    func = delayed(metrics.pairwise_distances_argmin_min, pure=True, nout=2)
    blocks = [func(x, Y, metric=metric, batch_size=batch_size,
                   metric_kwargs=metric_kwargs)
              for x in XD]
    argmins, mins = zip(*blocks)

    # Each block's result has one entry per row in that block, so its
    # length equals the matching row-chunk size of X.
    argmins = [da.from_delayed(block, (chunksize,), np.int64)
               for block, chunksize in zip(argmins, X.chunks[0])]
    # Scikit-learn seems to always use float64
    mins = [da.from_delayed(block, (chunksize,), 'f8')
            for block, chunksize in zip(mins, X.chunks[0])]
    argmins = da.concatenate(argmins)
    mins = da.concatenate(mins)
    return argmins, mins
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号