cluster.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:StarData 作者: TorchCraft 项目源码 文件源码
def _cluster(infn, outfn):
    """Cluster the points parsed from ``infn`` with MeanShift and save the results.

    Files derived from the ``outfn`` prefix:

    * ``outfn + '.lock'`` -- marker created before work starts and removed
      after the plot has been saved.
    * ``outfn + '.txt'``  -- produced by ``extract_battles``.
    * ``outfn + '.png'``  -- 3D scatter plot of the points, coloured per
      cluster.

    The call returns immediately, doing nothing, when either the ``.lock`` or
    the ``.txt`` file already exists.

    Settings are read from the ``args`` object defined outside this function:
    ``args.bandwidth`` (a negative value means "estimate it from the data"),
    ``args.min_deaths`` (clusters with fewer points are drawn in black and
    get no centre marker) and ``args.show`` (also display the plot).

    Args:
        infn: Input file name handed to ``parse_file``.
        outfn: Output path prefix (no extension).

    Returns:
        None.
    """
    lock_fn = outfn + '.lock'
    if path.exists(lock_fn) or path.exists(outfn + '.txt'):
        return
    # Create the (empty) lock marker; ``with`` guarantees the handle is closed.
    with open(lock_fn, 'w'):
        pass
    print("doing " + infn)
    data, xyt, transform, untransform, valid, maxes = parse_file(infn)
    if not valid:
        # NOTE(review): the lock file is left in place here, so later calls
        # skip this input -- confirm that this is intended.
        return
    bandwidth = args.bandwidth
    if bandwidth < 0:
        bandwidth = estimate_bandwidth(xyt, quantile=0.2, n_samples=500)

    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    ms.fit(xyt)
    # Clustering ran on the transformed points; map everything back with
    # ``untransform`` before writing and plotting.
    centers = untransform(ms.cluster_centers_)
    radius = untransform(bandwidth)  # NOTE(review): currently unused below
    xyt = untransform(xyt)

    labels = ms.labels_
    n_clusters_ = len(np.unique(labels))
    # few[k] is True when cluster k holds fewer than args.min_deaths points.
    few = np.bincount(labels) < args.min_deaths
    extract_battles(outfn + '.txt', data, ms, maxes, xyt, transform, untransform)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
    for k, too_few in zip(range(n_clusters_), few):
        my_members = labels == k
        cluster_center = centers[k]
        if too_few:
            col = 'black'
        else:
            col = next(colors)
            # The marker must be given by keyword: the fourth positional
            # argument of Axes3D.scatter is ``zdir``, not the marker.
            ax.scatter(cluster_center[0], cluster_center[1], cluster_center[2],
                       marker='o', c=col, s=100)
        ax.scatter(xyt[my_members, 0], xyt[my_members, 1], xyt[my_members, 2], c=col)
    # Save before showing: once an interactive window has been closed the
    # figure may no longer hold the drawing, which would give a blank image.
    fig.savefig(outfn + ".png")
    if args.show:
        plt.show()
    plt.close(fig)
    try:
        os.remove(lock_fn)
    except OSError:
        # Best effort: the lock may already have been removed.
        pass
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号