mnplots.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:clust 作者: BaselAbujamous 项目源码 文件源码
def mnplotsdistancethreshold(dists, method='bimodal', returnmodel=False):
    """Select the leading group of distances that fall below a data-driven threshold.

    Parameters
    ----------
    dists : array-like of numbers
        Distances to threshold. Any shape is accepted; values are flattened.
        Both methods treat the first entry as the reference point and the
        'largestgap' method compares consecutive entries, so ``dists`` is
        presumably expected to be sorted in ascending order -- this function
        does not sort it (TODO confirm against callers).
    method : {'bimodal', 'largestgap', 'largest_gap'}
        'bimodal' fits a two-component Gaussian mixture to the distances and
        keeps every entry assigned to the same component as the first entry.
        'largestgap' / 'largest_gap' cuts at the largest gap between
        consecutive entries, with earlier gaps weighted more heavily.
    returnmodel : bool
        Only honoured by the 'bimodal' method. If True, return
        ``(labels, model)`` where ``model`` is the fitted ``GaussianMixture``,
        or None when fewer than two distances are given (no model can be
        fitted in that case).

    Returns
    -------
    labels : numpy.ndarray of int
        Zero-based indices into the (flattened) ``dists`` of the selected
        entries. Empty for empty input; ``[0]`` for a single distance.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    values = np.asarray(dists).ravel()
    n = values.size

    if method == 'bimodal':
        if n < 2:
            # A 2-component mixture cannot be fitted to fewer than 2 samples
            # (sklearn rejects it), so handle the trivial cases directly.
            labels = np.arange(n)
            model = None
        else:
            column = values.reshape(-1, 1)
            model = skmix.GaussianMixture(n_components=2)
            model.fit(column)
            components = model.predict(column)
            # Keep the entries that share the first entry's component.
            labels = np.nonzero(components == components[0])[0]
        if returnmodel:
            return labels, model
        return labels

    if method in ('largestgap', 'largest_gap'):
        if n < 2:
            # No gaps to compare: keep whatever single entry exists (if any).
            return np.arange(n)
        gaps = np.diff(values)
        # Weight gaps (higher weight for first clusters): the first gap gets
        # weight n-1, the last gap gets weight 1.
        wgaps = gaps * np.arange(gaps.size, 0, -1)
        # Keep every entry up to and including the left side of the cut.
        return np.arange(0, np.argmax(wgaps) + 1)

    raise ValueError('Invalid method submitted to mnplotsdistancethreshold. '
                     "Use either 'bimodal' or 'largestgap'")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号