def visualize_encodings(encodings, file_name=None,
                        grid=None, skip_every=999, fast=False, fig=None, interactive=False):
    """Draw several 2-D / 3-D manifold projections of ``encodings`` on one figure.

    ``encodings`` is first reduced with ``manual_pca``. If the result has at most
    3 columns, it is handed to ``print_data_only`` and that call's result is
    returned. Otherwise only the first 720 rows are used. Each projection in
    the list below (LLE ltsa/modified, MDS euclidean/cosine, t-SNE) gets its own
    subplot. The remaining four subplot slots are filled by
    ``visualize_data_same``. Unless ``interactive`` is set, the figure is then
    saved with ``save_fig``.

    Args:
        encodings: array of samples. ``encodings.shape[1]`` is read, so it is
            assumed to be 2-D (samples x features) after ``manual_pca``.
        file_name: forwarded to ``print_data_only`` / ``save_fig``.
        grid: ``(rows, cols)`` subplot grid. Defaults to ``(3, 4)``.
        skip_every: while assigning subplot slots, every ``skip_every``-th slot
            is left empty. The default of 999 effectively disables skipping.
        fast: currently unused; kept for backward compatibility.
        fig: figure handle forwarded to ``get_figure``.
        interactive: when True, the figure is not saved to disk.

    Raises:
        ValueError: if ``skip_every`` is smaller than 2. The slot arithmetic
            divides by ``skip_every - 1``, so smaller values cannot work.
    """
    if skip_every < 2:
        raise ValueError("skip_every must be >= 2, got %r" % (skip_every,))

    encodings = manual_pca(encodings)
    if encodings.shape[1] <= 3:
        # Already low-dimensional: plot the data directly, no manifold learning.
        return print_data_only(encodings, file_name, fig=fig, interactive=interactive)

    # Cap the sample count. The O(n^2) distance matrices and MDS/t-SNE get slow fast.
    encodings = encodings[0:720]
    # Pairwise distance matrices, fed to estimators configured with
    # dissimilarity='precomputed'. (Despite the "hessian" name, these are
    # distance matrices, not Hessians.)
    hessian_euc = dist.squareform(dist.pdist(encodings, 'euclidean'))
    hessian_cos = dist.squareform(dist.pdist(encodings, 'cosine'))
    grid = (3, 4) if grid is None else grid

    project_ops = []
    n = 2
    project_ops.append(("LLE ltsa N:%d" % n, mn.LocallyLinearEmbedding(10, n, method='ltsa')))
    project_ops.append(("LLE modified N:%d" % n, mn.LocallyLinearEmbedding(10, n, method='modified')))
    project_ops.append(('MDS euclidean N:%d' % n, mn.MDS(n, max_iter=300, n_init=1, dissimilarity='precomputed')))
    # NOTE(review): recent scikit-learn releases renamed TSNE's `n_iter` to
    # `max_iter`; confirm against the version this project pins.
    project_ops.append(("TSNE 30/2000 N:%d" % n, TSNE(perplexity=30, n_components=n, init='pca', n_iter=2000)))
    n = 3
    project_ops.append(("LLE ltsa N:%d" % n, mn.LocallyLinearEmbedding(10, n, method='ltsa')))
    project_ops.append(("LLE modified N:%d" % n, mn.LocallyLinearEmbedding(10, n, method='modified')))
    project_ops.append(('MDS euclidean N:%d' % n, mn.MDS(n, max_iter=300, n_init=1, dissimilarity='precomputed')))
    project_ops.append(('MDS cosine N:%d' % n, mn.MDS(n, max_iter=300, n_init=1, dissimilarity='precomputed')))

    # One slot per projection plus the four slots used by visualize_data_same
    # (8 + 4 = 12 with the list above). Every skip_every-th grid position is skipped.
    n_slots = len(project_ops) + 4
    plot_places = []
    for i in range(n_slots):
        u, v = divmod(i, skip_every - 1)
        plot_places.append(v + u * skip_every + 1)

    fig = get_figure(fig)
    width, height = fig.get_size_inches()
    fig.set_size_inches(width * grid[0], height * grid[1] / 2.0)

    for i, (name, manifold) in enumerate(project_ops):
        is3d = 'N:3' in name
        try:
            if is3d:
                subplot = plt.subplot(grid[0], grid[1], plot_places[i], projection='3d')
            else:
                subplot = plt.subplot(grid[0], grid[1], plot_places[i])
            # _needs_hessian presumably flags the precomputed-distance estimators
            # (the MDS entries); the cosine matrix is picked by the op's name.
            if _needs_hessian(manifold):
                data_source = hessian_cos if 'cosine' in name else hessian_euc
            else:
                data_source = encodings
            projections = manifold.fit_transform(data_source)
            scatter(subplot, projections, is3d, _build_radial_colors(len(data_source)))
            subplot.set_title(name)
        except Exception as e:
            # Best effort: one failing projection must not abort the remaining
            # plots. Unlike the former bare `except:`, Ctrl-C still propagates.
            print(name, "Unexpected error: ", type(e), e)

    visualize_data_same(encodings, grid=grid, places=plot_places[-4:])
    if not interactive:
        save_fig(file_name, fig)
    ut.print_time('visualization finished')
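# 评论列表 (comment list) -- navigation residue from the web page this code was copied from
# 文章目录 (table of contents) -- navigation residue from the same page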