network.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:sakmapper 作者: szairis 项目源码 文件源码
def gap(data, refs=None, nrefs=20, ks=range(1,11), method=None):
    """Compute a gap-style statistic for each candidate number of clusters.

    For every ``k`` in ``ks`` a fresh ``method(n_clusters=k)`` is fitted on
    ``data`` and on every reference set.  The dispersion of a fitted set is
    the sum, over its rows, of the Euclidean distance between the row and the
    center of the cluster it was assigned to.  The value reported for ``k`` is
    ``log(mean(reference dispersions)) - log(data dispersion)``.

    NOTE(review): the published gap statistic (Tibshirani et al.) averages
    the *log* dispersions; this function takes the log of the mean.  That
    behavior is kept as-is for backward compatibility -- confirm intent.

    Args:
        data: 2-D array-like; rows are the items being clustered.
        refs: Optional 3-D array whose last axis indexes the reference sets
            (``refs[:, :, j]`` is the j-th set).  When ``None``, ``nrefs``
            sets are drawn uniformly inside the per-column [min, max] box of
            ``data`` using NumPy's global random state.
        nrefs: Number of reference sets to generate when ``refs`` is ``None``.
        ks: Iterable of cluster counts to evaluate.
        method: Callable accepting ``n_clusters=k`` and returning an object
            whose ``fit(X)`` result exposes ``cluster_centers_`` and
            ``labels_`` (presumably a scikit-learn style clusterer class).

    Returns:
        1-D float ``numpy.ndarray`` with one entry per element of ``ks``.

    Raises:
        TypeError: If ``method`` is not supplied.
    """
    # Function-scope import: the NumPy aliases formerly reachable as
    # ``scipy.zeros`` / ``scipy.log`` / ``scipy.random`` ... were deprecated
    # in SciPy 1.4 and later removed, so NumPy is used directly.
    import numpy as np

    if method is None:
        # Same exception type the bare ``None(...)`` call used to raise,
        # but with an actionable message and before any work is done.
        raise TypeError(
            "gap() requires `method`: a clusterer factory accepting n_clusters=k"
        )

    data = np.asarray(data)
    ks = list(ks)  # allow one-shot iterables; len() is needed below

    if refs is None:
        n_rows, n_cols = data.shape
        bots = data.min(axis=0)
        span = data.max(axis=0) - bots
        # One draw for all reference sets, then scale/shift each column
        # (axis 1) into the bounding box of the data.
        rands = np.random.random_sample(size=(n_rows, n_cols, nrefs))
        rands = rands * span[:, None] + bots[:, None]
    else:
        rands = np.asarray(refs)

    def dispersion(points, k):
        # Fit a fresh clusterer and sum each row's distance to its center.
        # Uses the row count of `points` itself, so reference sets may have
        # a different number of rows than `data`.
        fitted = method(n_clusters=k).fit(points)
        centers = np.asarray(fitted.cluster_centers_)
        labels = np.asarray(fitted.labels_)
        return np.linalg.norm(points - centers[labels], axis=1).sum()

    n_sets = rands.shape[2]
    gaps = np.zeros(len(ks))
    for i, k in enumerate(ks):
        disp = dispersion(data, k)
        refdisps = np.array(
            [dispersion(rands[:, :, j], k) for j in range(n_sets)]
        )
        gaps[i] = np.log(refdisps.mean()) - np.log(disp)
    return gaps
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号