components.py source code

python

Project: sptgraph  Author: epfl-lts2
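def best_shape_clustering(mols, nb_layers, k_range=range(3, 20), train_ratio=0.8, cluster_key='shape_cid'):
    """Choose the number of shape clusters k that maximizes the silhouette score.

    For each k, cluster_shapes() receives the training subset and the full
    shape matrix and returns (centroids, labels) with one label per row of the
    full matrix. The best labelling is stored in mols[cluster_key].
    """
    import numpy as np
    # sklearn.cross_validation was removed in scikit-learn 0.20; use model_selection
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import silhouette_score

    # temporal_shape() and cluster_shapes() are helpers defined elsewhere (not shown here)
    shape_df = mols['dynamic'].apply(lambda x: temporal_shape(x, nb_layers))
    train_idx, _ = train_test_split(shape_df.index.values, train_size=train_ratio)

    # Stack the per-row shape vectors into 2-D (n_samples, n_features) arrays
    train_mat = np.array(list(shape_df[shape_df.index.isin(train_idx)].values))
    full_mat = np.array(list(shape_df.values))

    centroids = None
    labels = None
    # Silhouette scores lie in [-1, 1]; starting at 0 would leave labels as None
    # whenever every k scores negatively, so start at -inf instead
    best_score = -np.inf
    for k in k_range:
        res = cluster_shapes(train_mat, full_mat, k)  # res[0]: centroids, res[1]: labels
        score = silhouette_score(full_mat, res[1])
        if score > best_score:
            centroids = res[0]
            labels = res[1]
            best_score = score

    mols[cluster_key] = labels
    return mols, centroids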
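The k-selection loop above can be illustrated with a self-contained sketch. This is not the sptgraph implementation: it assumes scikit-learn's `KMeans` as a stand-in for the module's `cluster_shapes` helper (whose internals are not shown), uses a hypothetical function name `select_k_by_silhouette`, and runs on synthetic 2-D blobs rather than temporal shapes. As in the original, clusters are fit on a training subset, every row is then labelled, and each k is scored by the silhouette of the full labelling.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.model_selection import train_test_split


def select_k_by_silhouette(X, k_range=range(2, 10), train_ratio=0.8, random_state=0):
    """Hypothetical helper: KMeans stands in for cluster_shapes (an assumption).

    Fits KMeans on a training subset for each k, labels every row of X,
    and keeps the k whose full labelling has the highest silhouette score.
    """
    train_X, _ = train_test_split(X, train_size=train_ratio, random_state=random_state)

    best_k, best_labels, best_centroids = None, None, None
    best_score = -np.inf  # silhouette lies in [-1, 1]
    for k in k_range:
        km = KMeans(n_clusters=k, n_init=10, random_state=random_state).fit(train_X)
        labels = km.predict(X)  # assign all rows, not just the training ones
        if len(np.unique(labels)) < 2:
            continue  # silhouette_score needs at least two distinct labels
        score = silhouette_score(X, labels)
        if score > best_score:
            best_k, best_labels, best_centroids = k, labels, km.cluster_centers_
            best_score = score
    return best_k, best_labels, best_centroids, best_score


# Synthetic data: three well-separated Gaussian blobs, 50 points each
rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0]])
X = np.vstack([c + rng.normal(scale=0.3, size=(50, 2)) for c in centers])

k, labels, centroids, score = select_k_by_silhouette(X, k_range=range(2, 7))
print(k, round(score, 3))
```

Passing `random_state` makes the split and the KMeans initialisation reproducible; the original `train_test_split` call has no seed, so repeated runs of `best_shape_clustering` may pick different k.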