cluster_tools.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:SUPPA 作者: comprna 项目源码 文件源码
def calculate_cluster_scores(x, cluster_labels, output):
    """Report per-cluster statistics and quality scores for a clustering.

    Every line of the report is printed to stdout and also written to the
    file ``"<output>_scores.log"``, which is created or truncated.

    Args:
        x: Iterable of vectors, one per event, paired element-wise with
            ``cluster_labels`` (pairing stops at the shorter of the two).
        cluster_labels: Iterable of cluster labels. The label ``-1`` marks an
            event that was not assigned to any cluster; such events are
            counted but excluded from every score.
        output: Prefix used to build the log file name.

    Returns:
        A tuple ``(silhouette_avg, rmsstd_scores)``. ``silhouette_avg`` is the
        value returned by ``silhouette_score`` over the clustered events, or
        ``nan`` when that call rejects the input with ``ValueError``.
        ``rmsstd_scores`` is a list holding the ``calculate_rmsstd`` result
        for each cluster, in the iteration order of the cluster grouping.
    """
    with open("%s_scores.log" % output, "w+") as fh:

        def report(message):
            # Keep stdout and the log file in sync: same text, one line each.
            print(message)
            fh.write(message + "\n")

        # Separate clustered events from the singleton "cluster" (label -1).
        filtered_x, filtered_cluster_labels, singletons = [], [], []
        cluster_groups = defaultdict(list)
        for vec, lab in zip(x, cluster_labels):
            if lab == -1:
                singletons.append(vec)
                continue
            filtered_x.append(vec)
            filtered_cluster_labels.append(lab)
            cluster_groups[lab].append(vec)

        n_clustered = len(filtered_x)
        n_total = n_clustered + len(singletons)
        # Empty input would otherwise divide by zero; report 0% instead.
        pct_clustered = (n_clustered / n_total) * 100 if n_total else 0.0
        report("Number of clustered events: %d/%d (%f%%)"
               % (n_clustered, n_total, pct_clustered))

        for group, members in cluster_groups.items():
            report("Cluster %d contains %d events" % (group, len(members)))

        rmsstd_scores = []
        for group, members in cluster_groups.items():
            rmsstd = calculate_rmsstd(np.array(members))
            report("The RMSSTD score for cluster %d is %f" % (group, rmsstd))
            rmsstd_scores.append(rmsstd)

        try:
            silhouette_avg = silhouette_score(np.array(filtered_x),
                                              np.array(filtered_cluster_labels))
        except ValueError:
            # Only the "invalid input / invalid number of labels" failure is
            # treated as "no score"; anything else (KeyboardInterrupt,
            # NameError, ...) must propagate instead of being mislabelled.
            silhouette_avg = float("nan")
            report("Impossible to calculate silhouette score. "
                   "Only 1 cluster group identified.")
        else:
            report("The average silhouette score is : %f" % silhouette_avg)

    return silhouette_avg, rmsstd_scores
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号