def plot_encoding(model, sess, title=None, name="encoding",
                  datasets=("test","validation"), range_=(-4.,4.),
                  save=True, outdir="./plots", **kwargs):
    """
    Scatter-plot dataset images encoded into a 2-dimensional latent space.

    For every name in `datasets`, the matching attribute of `model.dataset`
    is fed through `model.z_dist_info_`, and the first entry named by
    `model.latent_dist.dist_info_keys` is drawn as a scatter plot. Points are
    colored per class (with a legend) when the dataset has labels.

    Args:
        model: object exposing `dataset`, `x_in_test`, `z_dist_info_`,
            `latent_dist.dist_info_keys`, `counter`, `target_dist.stddev`
            and, when saving, `datetime`.
        sess: session object; `sess.run(fetches, feed_dict=...)` is used to
            evaluate the encoding.
        title: accepted for interface compatibility; the figure title is
            built from `model.counter` and the dataset name.
        name: accepted for interface compatibility; not used in the output.
        datasets: attribute names looked up on `model.dataset`; one figure
            is produced per name.
        range_: (low, high) axis limits expressed in multiples of
            `model.target_dist.stddev`, applied to both axes. A falsy value
            (e.g. None) leaves the axis limits automatic.
        save: if true, write each figure as a PNG into `outdir` and close it.
        outdir: output directory; created if it does not exist yet.
        **kwargs: extra keyword arguments forwarded to `plt.scatter`. They
            override the default `alpha=0.8`; `c` is replaced by the
            per-class colors whenever the dataset has labels.

    Raises:
        ValueError: if the encoded centers are not of shape (N, 2).
    """
    for dataset_name in datasets:
        dataset = getattr(model.dataset, dataset_name)
        feed_dict = {model.x_in_test: dataset.images}
        encoding = sess.run(model.z_dist_info_, feed_dict=feed_dict)
        centers = np.asarray(encoding[model.latent_dist.dist_info_keys[0]])
        # Fail with a readable message instead of an opaque unpacking error.
        if centers.ndim != 2 or centers.shape[1] != 2:
            raise ValueError(
                "plot_encoding needs a 2-dimensional latent space; got "
                "centers of shape {}".format(centers.shape))
        ys, xs = centers.T

        plt.figure()
        plt.title("round {}: {} in latent space".format(model.counter,
                                                        dataset_name))

        # Fresh dict per figure: defaults first, caller overrides on top.
        scatter_kwargs = {'alpha': 0.8}
        scatter_kwargs.update(kwargs)

        # Sort the classes so points and legend share one stable
        # class -> color mapping, whatever the label values are.
        classes = sorted(set(dataset.labels))
        if classes:
            colormap = plt.cm.rainbow(np.linspace(0, 1, len(classes)))
            color_of = {class_: colormap[i]
                        for i, class_ in enumerate(classes)}
            scatter_kwargs['c'] = [color_of[label]
                                   for label in dataset.labels]
            # make room for legend
            ax = plt.subplot(111)
            box = ax.get_position()
            ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
            handles = [mpatches.Circle((0,0), label=class_,
                                       color=color_of[class_])
                       for class_ in classes]
            ax.legend(handles=handles, shadow=True,
                      bbox_to_anchor=(1.05, 0.45),
                      fancybox=True, loc='center left')

        plt.scatter(xs, ys, **scatter_kwargs)

        if range_:
            # map range_ to standard deviations of the target distribution
            stddev = model.target_dist.stddev
            adjusted_range = (stddev*range_[0], stddev*range_[1])
            plt.xlim(adjusted_range)
            plt.ylim(adjusted_range)

        if save:
            if not os.path.isdir(outdir):
                os.makedirs(outdir)
            filename = "{}_encoding_{}_round_{}.png".format(
                model.datetime, dataset_name, model.counter)
            plt.savefig(os.path.join(outdir, filename), bbox_inches="tight")
            # Closed only after saving, so an unsaved figure stays open
            # (presumably for interactive display by the caller).
            plt.close()
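# NOTE: stray web-page navigation text ("comment list" / "article table of
# contents") was pasted here; it is not part of this module.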