def visualize_10_2d_gaussian_prior(n_z, y_label, visualization_dir=None):
    """Scatter-plot sampled 2-D points, coloured by their class label.

    One sample per entry of ``y_label`` is requested from
    ``sample_z_from_n_2d_gaussian_mixture``. Columns 0 and 1 of the result
    are drawn as the "z1" and "z2" coordinates, and a ten-entry legend
    (classes "0" to "9") is placed to the right of the axes.

    Args:
        n_z: Passed straight through to the sampler (presumably the latent
            dimensionality -- confirm against the sampler's signature).
        y_label: Indexable sequence of integer class labels. Each label is
            used as an index into the ten-colour palette, so values must
            lie in 0..9.
        visualization_dir: Directory in which "10_2d-gaussian.png" is
            written. When ``None`` (the default) nothing is saved.

    The figure is shown with ``pylab.show()`` whether or not it was saved.
    """
    z_batch = sample_z_from_n_2d_gaussian_mixture(len(y_label), n_z, y_label, 10, False)
    # NOTE(review): `.data` presumably unwraps a framework variable into a
    # plain array -- confirm against the sampler's return type.
    z_batch = z_batch.data

    fig = pylab.gcf()
    fig.set_size_inches(15, 12)
    pylab.clf()

    colors = ["#2103c8", "#0e960e", "#e40402", "#05aaa8", "#ac02ab",
              "#aba808", "#151515", "#94a169", "#bec9cd", "#6a6551"]
    classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]

    # Draw every point in a single scatter call with a per-point colour list
    # instead of creating one artist per point. `range` (not `xrange`) keeps
    # this working on both Python 2 and Python 3.
    num_points = z_batch.shape[0]
    point_colors = [colors[y_label[n]] for n in range(num_points)]
    pylab.scatter(z_batch[:, 0], z_batch[:, 1], c=point_colors, s=40, marker="o",
                  edgecolors='none')

    # Proxy rectangles, one per class colour, used only as legend handles.
    recs = [mpatches.Rectangle((0, 0), 1, 1, fc=color) for color in colors]

    ax = pylab.subplot(111)
    box = ax.get_position()
    # Shrink the axes to 80% of their width so the legend fits on the right.
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(recs, classes, loc="center left", bbox_to_anchor=(1.1, 0.5))

    pylab.xticks(pylab.arange(-4, 5))
    pylab.yticks(pylab.arange(-4, 5))
    pylab.xlabel("z1")
    pylab.ylabel("z2")

    if visualization_dir is not None:
        pylab.savefig("%s/10_2d-gaussian.png" % visualization_dir)
    pylab.show()
utils.py 文件源码
python
阅读 27
收藏 0
点赞 0
评论 0
评论列表
文章目录