network.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:sakmapper 作者: szairis 项目源码 文件源码
def mapper_graph(df, lens_data=None, lens='pca', resolution=10, gain=0.5, equalize=True, clust='kmeans', stat='db',
                 max_K=5):
    """
    Build a mapper graph: cover the lens image with patches, cluster each
    non-empty patch, and connect clusters that share at least one member.

    input:
        df         -- raw data (a dataframe per the original documentation);
                      handed to apply_lens and to optimal_clustering
        lens_data  -- N x n_dim image of the raw data under the lens function;
                      computed with apply_lens(df, lens=lens) when None
        lens       -- lens name forwarded to apply_lens (only used when
                      lens_data is None)
        resolution, gain, equalize
                   -- forwarded unchanged to covering_patches
        clust, stat, max_K
                   -- forwarded to optimal_clustering as method, statistic
                      and max_K
    output: (undirected graph, list of node contents, dictionary of patches)
        The graph's nodes are relabelled 0..k-1 so that node i corresponds to
        entry i of the list of node contents. The dictionary of patches is
        returned exactly as produced by covering_patches.

    A cluster holding a single member that also occurs in some other cluster
    is treated as redundant: its node is removed from the graph and it is
    left out of the returned node contents.
    """
    from collections import Counter

    if lens_data is None:
        lens_data = apply_lens(df, lens=lens)

    patches = covering_patches(lens_data, resolution=resolution, gain=gain, equalize=equalize)

    patch_clusterings = {}
    for key, patch in patches.items():
        # Explicit len() instead of truthiness: a patch may be an array-like
        # whose boolean value is ambiguous -- TODO confirm the patch type.
        if len(patch) > 0:
            patch_clusterings[key] = optimal_clustering(df, patch, method=clust, statistic=stat, max_K=max_K)
    # Single-argument parenthesised print emits the same text on Python 2 and 3.
    print('total of {} patches required clustering'.format(len(patch_clusterings)))

    # Flatten the per-patch clusterings; every cluster becomes one graph node.
    # Assumes each clustering is a list of clusters -- TODO confirm against
    # optimal_clustering.
    all_clusters = []
    for clustering in patch_clusterings.values():
        all_clusters += clustering
    num_nodes = len(all_clusters)
    print('this implies {} nodes in the mapper graph'.format(num_nodes))

    # Build each membership set once rather than once per compared pair.
    member_sets = [set(cluster) for cluster in all_clusters]
    A = np.zeros((num_nodes, num_nodes))
    for i in range(num_nodes):
        for j in range(i):
            if not member_sets[i].isdisjoint(member_sets[j]):
                A[i, j] = 1
                A[j, i] = 1

    # NOTE(review): from_numpy_matrix was removed in NetworkX 3.0 in favour of
    # from_numpy_array -- confirm which NetworkX version this project pins.
    G = nx.from_numpy_matrix(A)

    # How many clusters each member occurs in, counted over all clusters.
    occurrences = Counter()
    for cluster in all_clusters:
        occurrences.update(cluster)

    kept_clusters = []
    mapping = {}
    for node, cluster in enumerate(all_clusters):
        # The previous check counted the cluster object itself inside the flat
        # member list, which never matches for list-valued clusters; the
        # cluster's single member is what has to be counted.
        if len(cluster) == 1 and occurrences[next(iter(cluster))] > 1:
            G.remove_node(node)
        else:
            # Surviving nodes get consecutive labels matching their position
            # in kept_clusters.
            mapping[node] = len(kept_clusters)
            kept_clusters.append(cluster)
    H = nx.relabel_nodes(G, mapping)
    return H, kept_clusters, patches
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号