visualization.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:TensorFlow_DCIGN 作者: yselivonchyk 项目源码 文件源码
def visualize_encodings(encodings, file_name=None,
                        grid=None, skip_every=999, fast=False, fig=None, interactive=False):
  """Draw several manifold projections of `encodings` side by side on one figure.

  The encodings are first passed through `manual_pca`. If the result has at
  most three columns, plotting is delegated to `print_data_only` and its
  result is returned. Otherwise the first 720 rows are projected with eight
  embeddings (LLE, MDS and t-SNE variants in 2 and 3 dimensions), each drawn
  in its own subplot, and `visualize_data_same` is called for the last four
  subplot positions.

  Args:
    encodings: 2-D array-like of encodings, one row per sample
      (it must support `.shape` and row slicing after `manual_pca`).
    file_name: target passed to `save_fig` / `print_data_only`.
    grid: `(rows, cols)` subplot grid; defaults to `(3, 4)`.
    skip_every: every `skip_every`-th subplot position is left unused when
      the 12 plot slots are mapped onto the grid. A value of 1 raises
      `ZeroDivisionError`.
    fast: accepted for interface compatibility; not used by this function.
    fig: figure handed to `get_figure`.
    interactive: when true, the figure is not saved via `save_fig`.

  Returns:
    The result of `print_data_only` on the low-dimensional path, else `None`.
  """
  encodings = manual_pca(encodings)
  if encodings.shape[1] <= 3:
    return print_data_only(encodings, file_name, fig=fig, interactive=interactive)

  # Cap the sample count; the pairwise distance matrices below grow quadratically.
  max_points = 720
  encodings = encodings[0:max_points]
  # Square pairwise-distance matrices fed to the 'precomputed' MDS estimators.
  # NOTE(review): `dist` is presumably scipy.spatial.distance -- confirm.
  dist_euclidean = dist.squareform(dist.pdist(encodings, 'euclidean'))
  dist_cosine = dist.squareform(dist.pdist(encodings, 'cosine'))
  grid = (3, 4) if grid is None else grid

  # (title, estimator) pairs. The 'N:<dim>' suffix of the title is parsed below
  # to pick a 2-D or 3-D subplot, and 'cosine' selects the cosine distances.
  # Estimator arguments are passed by keyword because recent scikit-learn
  # releases reject them positionally (assumes `mn` is sklearn.manifold).
  # NOTE(review): newer scikit-learn renames TSNE's `n_iter` to `max_iter` --
  # check against the pinned version before upgrading.
  project_ops = []
  n = 2
  project_ops.append(("LLE ltsa       N:%d" % n,
                      mn.LocallyLinearEmbedding(n_neighbors=10, n_components=n, method='ltsa')))
  project_ops.append(("LLE modified   N:%d" % n,
                      mn.LocallyLinearEmbedding(n_neighbors=10, n_components=n, method='modified')))
  project_ops.append(('MDS euclidean  N:%d' % n,
                      mn.MDS(n_components=n, max_iter=300, n_init=1, dissimilarity='precomputed')))
  project_ops.append(("TSNE 30/2000   N:%d" % n,
                      TSNE(perplexity=30, n_components=n, init='pca', n_iter=2000)))
  n = 3
  project_ops.append(("LLE ltsa       N:%d" % n,
                      mn.LocallyLinearEmbedding(n_neighbors=10, n_components=n, method='ltsa')))
  project_ops.append(("LLE modified   N:%d" % n,
                      mn.LocallyLinearEmbedding(n_neighbors=10, n_components=n, method='modified')))
  project_ops.append(('MDS euclidean  N:%d' % n,
                      mn.MDS(n_components=n, max_iter=300, n_init=1, dissimilarity='precomputed')))
  project_ops.append(('MDS cosine     N:%d' % n,
                      mn.MDS(n_components=n, max_iter=300, n_init=1, dissimilarity='precomputed')))

  # Map 12 plot slots (8 projections + 4 for visualize_data_same) to 1-based
  # subplot indices, leaving every `skip_every`-th grid position empty.
  plot_places = []
  for i in range(12):
    u, v = int(i / (skip_every - 1)), i % (skip_every - 1)
    plot_places.append(v + u * skip_every + 1)

  fig = get_figure(fig)
  width, height = fig.get_size_inches()[0], fig.get_size_inches()[1]
  fig.set_size_inches(width * grid[0] / 1., height * grid[1] / 2.0)

  for i, (name, manifold) in enumerate(project_ops):
    is3d = 'N:3' in name

    try:
      if is3d:
        subplot = plt.subplot(grid[0], grid[1], plot_places[i], projection='3d')
      else:
        subplot = plt.subplot(grid[0], grid[1], plot_places[i])

      if _needs_hessian(manifold):
        data_source = dist_cosine if 'cosine' in name else dist_euclidean
      else:
        data_source = encodings
      projections = manifold.fit_transform(data_source)
      scatter(subplot, projections, is3d, _build_radial_colors(len(data_source)))
      subplot.set_title(name)
    except Exception as err:
      # Best effort: one failing projection must not abort the whole figure.
      # Catching Exception (not a bare except) keeps Ctrl-C / SystemExit working.
      print(name, "Unexpected error: ", type(err), err)

  visualize_data_same(encodings, grid=grid, places=plot_places[-4:])
  if not interactive:
    save_fig(file_name, fig)
  ut.print_time('visualization finished')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号