def mapper_graph(df, lens_data=None, lens='pca', resolution=10, gain=0.5, equalize=True, clust='kmeans', stat='db',
                 max_K=5):
    """
    Build a mapper graph from raw data and its image under a lens function.

    input: N x n_dim image of raw data under lens function, as a dataframe
    output: (undirected graph, list of node contents, dictionary of patches)

    Parameters
    ----------
    df : raw data; forwarded to ``apply_lens`` and ``optimal_clustering``.
    lens_data : precomputed lens image. When ``None`` it is computed as
        ``apply_lens(df, lens=lens)``.
    lens : lens name forwarded to ``apply_lens`` (only used when
        ``lens_data`` is ``None``).
    resolution, gain, equalize : forwarded unchanged to ``covering_patches``.
    clust, stat, max_K : forwarded to ``optimal_clustering`` as ``method``,
        ``statistic`` and ``max_K`` respectively.

    Returns
    -------
    (H, all_clusters_new, patches)
        ``H`` is the graph with one node per retained cluster, relabelled
        0..k-1, and an edge wherever two clusters share a member.
        ``all_clusters_new[i]`` is the cluster behind node ``i``.
        ``patches`` is the mapping returned by ``covering_patches``.

    Notes
    -----
    Each cluster is assumed to be a sized iterable of hashable member ids
    (presumably row labels of ``df`` -- confirm against
    ``optimal_clustering``).
    """
    from collections import Counter

    if lens_data is None:
        lens_data = apply_lens(df, lens=lens)

    patches = covering_patches(lens_data, resolution=resolution, gain=gain, equalize=equalize)

    # Only non-empty patches are clustered.
    patch_clusterings = {}
    for key, patch in patches.items():
        if len(patch) > 0:
            patch_clusterings[key] = optimal_clustering(df, patch, method=clust, statistic=stat, max_K=max_K)
    # Single-argument parenthesised print emits the same text on Python 2 and 3.
    print('total of {} patches required clustering'.format(len(patch_clusterings)))

    # Every cluster of every patch becomes one candidate node.
    all_clusters = []
    for key in patch_clusterings:
        all_clusters += patch_clusterings[key]
    num_nodes = len(all_clusters)
    print('this implies {} nodes in the mapper graph'.format(num_nodes))

    # Build the member sets once instead of once per pair.
    cluster_sets = [set(cluster) for cluster in all_clusters]

    # Symmetric adjacency: two clusters are linked when they share a member.
    A = np.zeros((num_nodes, num_nodes))
    for i in range(num_nodes):
        for j in range(i):
            if not cluster_sets[i].isdisjoint(cluster_sets[j]):
                A[i, j] = 1
                A[j, i] = 1

    # from_numpy_matrix was removed in NetworkX 3.0; prefer from_numpy_array
    # when it exists and fall back for older releases.
    build_graph = getattr(nx, 'from_numpy_array', None) or nx.from_numpy_matrix
    G = build_graph(A)

    # How many times each member appears across all clusters.
    member_counts = Counter()
    for cluster in all_clusters:
        member_counts.update(cluster)

    # Drop singleton clusters whose only member also appears elsewhere.
    # The original compared the cluster container itself against the flat
    # member list, which never matches for clusters of plain Python values.
    all_clusters_new = []
    mapping = {}
    cont = 0
    for node, cluster in enumerate(all_clusters):
        if len(cluster) == 1 and member_counts[next(iter(cluster))] > 1:
            G.remove_node(node)
        else:
            all_clusters_new.append(cluster)
            mapping[node] = cont
            cont += 1

    # Renumber surviving nodes so they index into all_clusters_new.
    H = nx.relabel_nodes(G, mapping)
    return H, all_clusters_new, patches
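# (web-page residue, not part of the code) Comment list
# (web-page residue, not part of the code) Table of contents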