plot.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:AAE-tensorflow 作者: gitmatti 项目源码 文件源码
def plot_encoding(model, sess, title=None, name="encoding",
                  datasets=("test","validation"), range_=(-4.,4.),
                  save=True, outdir="./plots", **kwargs):
    """
    Plotting utility to encode dataset images in (low dimensional!) latent space.

    For each dataset split, all images of the split are fed through
    ``model.z_dist_info_`` and the entry ``model.latent_dist.dist_info_keys[0]``
    of the result is scatter-plotted, coloured by ``dataset.labels``.

    Args:
        model: model object exposing ``dataset``, ``x_in_test``,
            ``z_dist_info_``, ``latent_dist``, ``target_dist``, ``counter``
            and ``datetime``.
        sess: session whose ``run`` evaluates ``model.z_dist_info_``.
        title, name: accepted for backward compatibility; currently not used
            for the figure title or the output filename.
        datasets: names of attributes of ``model.dataset`` to plot.
        range_: (low, high) axis limits expressed in multiples of
            ``model.target_dist.stddev``. Pass a falsy value (e.g. ``None``)
            to keep matplotlib's automatic axis limits.
        save: if True, save each figure as a PNG file in ``outdir``.
        outdir: output directory for saved figures; created if missing.
        **kwargs: accepted and ignored (kept for interface compatibility).

    Raises:
        ValueError: if the encoded centers are not an (N, 2) array, i.e. the
            latent space is not two-dimensional.
    """
    if save and not os.path.isdir(outdir):
        os.makedirs(outdir)

    for dataset_name in datasets:
        dataset = getattr(model.dataset, dataset_name)
        feed_dict = {model.x_in_test: dataset.images}
        encoding = sess.run(model.z_dist_info_, feed_dict=feed_dict)
        centers = encoding[model.latent_dist.dist_info_keys[0]]
        if np.ndim(centers) != 2 or np.shape(centers)[1] != 2:
            raise ValueError(
                "plot_encoding needs a 2-d latent space, got centers of "
                "shape {}".format(np.shape(centers)))
        # NOTE(review): the first latent dimension ends up on the y-axis and
        # the second on the x-axis; kept as in the original, confirm intent.
        ys, xs = centers.T

        fig = plt.figure()
        try:
            plt.title("round {}: {} in latent space".format(model.counter,
                                                            dataset_name))
            # Local name kept distinct from the ``**kwargs`` parameter so the
            # caller's argument is not silently shadowed.
            scatter_kwargs = {'alpha': 0.8}

            classes = sorted(set(dataset.labels))
            if classes:
                colormap = plt.cm.rainbow(np.linspace(0, 1, len(classes)))
                # Map each label to its rank among the sorted classes so that
                # arbitrary (e.g. non-contiguous) labels index the colormap
                # safely and points and legend share the same colours.
                class_index = {class_: i for i, class_ in enumerate(classes)}
                scatter_kwargs['c'] = [colormap[class_index[label]]
                                       for label in dataset.labels]
                # make room for legend
                ax = plt.subplot(111)
                box = ax.get_position()
                ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
                handles = [mpatches.Circle((0, 0), label=class_,
                                           color=colormap[i])
                           for i, class_ in enumerate(classes)]
                ax.legend(handles=handles, shadow=True,
                          bbox_to_anchor=(1.05, 0.45),
                          fancybox=True, loc='center left')

            plt.scatter(xs, ys, **scatter_kwargs)

            # Only index range_ once we know it is set; range_=None disables
            # the explicit limits instead of raising TypeError.
            if range_:
                # map range_ to standard deviations of the target distribution
                stddev = model.target_dist.stddev
                adjusted_range = (stddev * range_[0], stddev * range_[1])
                plt.xlim(adjusted_range)
                plt.ylim(adjusted_range)

            if save:
                filename = "{}_encoding_{}_round_{}.png".format(
                    model.datetime, dataset_name, model.counter)
                plt.savefig(os.path.join(outdir, filename),
                            bbox_inches="tight")
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号