def visualize_labeled_z(xp, model, x, y_label, visualization_dir, epoch, gpu=False):
    """Save a 2-D scatter plot of the encoded batch, coloured by class label.

    ``x`` is converted with ``xp.asarray``, wrapped in a ``chainer.Variable``
    and passed through ``model.encode(x, test=True)``. The first two columns
    of the result are plotted as ``z1`` / ``z2`` and the figure is written to
    ``"{visualization_dir}/labeled_z_{epoch}.png"``.

    Args:
        xp: Array module providing ``asarray`` (presumably ``numpy`` or
            ``cupy`` -- confirm against the caller).
        model: Object exposing ``encode(x, test=True)`` that returns a
            ``chainer.Variable`` with at least two columns.
        x: Input batch to encode.
        y_label: Per-sample labels; ``y_label[n]`` must be an integer index
            in ``0..9`` because it selects one of the ten colours below.
        visualization_dir: Directory the PNG is written into.
        epoch: Value substituted into the output file name.
        gpu: Unused by this function; kept for caller compatibility.
    """
    x = chainer.Variable(xp.asarray(x))
    z_batch = model.encode(x, test=True)
    # Bring the encoding to the CPU before handing the raw array to matplotlib.
    z_batch.to_cpu()
    z_batch = z_batch.data

    # Reuse the current figure and wipe it, so repeated calls do not pile up
    # figures or overdraw previous epochs.
    fig = pylab.gcf()
    fig.set_size_inches(8.0, 8.0)
    pylab.clf()

    colors = ["#2103c8", "#0e960e", "#e40402", "#05aaa8", "#ac02ab", "#aba808", "#151515", "#94a169", "#bec9cd",
              "#6a6551"]
    classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]

    num_points = z_batch.shape[0]
    if num_points:
        # One vectorized scatter call instead of one call (and one artist)
        # per sample; labels are still looked up element by element so any
        # index-like label type keeps working.
        point_colors = [colors[y_label[n]] for n in range(num_points)]
        pylab.scatter(z_batch[:, 0], z_batch[:, 1], c=point_colors, s=40, marker="o",
                      edgecolors='none')

    # Proxy rectangles act as legend handles, one per class colour.
    recs = [mpatches.Rectangle((0, 0), 1, 1, fc=color) for color in colors]

    ax = pylab.subplot(111)
    box = ax.get_position()
    # Shrink the axes to 80% width to leave room for the legend on the right.
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(recs, classes, loc="center left", bbox_to_anchor=(1.1, 0.5))

    pylab.xticks(pylab.arange(-4, 5))
    pylab.yticks(pylab.arange(-4, 5))
    pylab.xlabel("z1")
    pylab.ylabel("z2")
    pylab.savefig("{}/labeled_z_{}.png".format(visualization_dir, epoch))
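# Scraped page footer (not part of the program): "utils.py source file", language tag "python",
# views 24, favorites 0, likes 0, comments 0, comment list, table of contents.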