utils.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:chainer-adversarial-autoencoder 作者: fukuta0614 项目源码 文件源码
def visualize_labeled_z(xp, model, x, y_label, visualization_dir, epoch, gpu=False):
    """Scatter-plot the first two latent dimensions of ``x``, coloured by label, and save a PNG.

    Args:
        xp: Array module used to convert ``x`` before wrapping it in a
            ``chainer.Variable`` (presumably ``numpy`` or ``cupy`` — confirm with callers).
        model: Object exposing ``encode(variable, test=True)`` that returns a
            chainer ``Variable``; columns 0 and 1 of its data are plotted.
        x: Batch of inputs to encode.
        y_label: Per-sample integer labels, used as indices into the 10-entry
            colour palette below (assumes labels are in 0..9 — TODO confirm).
        visualization_dir: Directory the image is written to.
        epoch: Value interpolated into the output file name.
        gpu: Unused; kept only so existing callers keep working.

    Side effects:
        Clears and reuses the current pylab figure, then writes
        ``"{visualization_dir}/labeled_z_{epoch}.png"``.
    """
    x = chainer.Variable(xp.asarray(x))
    z_batch = model.encode(x, test=True)
    z_batch.to_cpu()
    z_batch = z_batch.data

    fig = pylab.gcf()
    fig.set_size_inches(8.0, 8.0)
    pylab.clf()
    colors = ["#2103c8", "#0e960e", "#e40402", "#05aaa8", "#ac02ab", "#aba808", "#151515", "#94a169", "#bec9cd",
              "#6a6551"]

    num_points = z_batch.shape[0]
    if num_points > 0:
        # A single scatter call for the whole batch instead of one artist per
        # point: much faster for large batches, same visual result (points are
        # drawn in index order). range() keeps this Python 3 compatible.
        point_colors = [colors[y_label[n]] for n in range(num_points)]
        pylab.scatter(z_batch[:, 0], z_batch[:, 1], c=point_colors, s=40, marker="o",
                      edgecolors='none')

    # Legend: one solid swatch per palette entry, labelled by class index.
    classes = [str(i) for i in range(len(colors))]
    recs = [mpatches.Rectangle((0, 0), 1, 1, fc=color) for color in colors]

    # gca() returns the axes scatter drew into (or creates one), avoiding the
    # deprecated "re-add subplot(111) to get the existing axes" pattern.
    ax = pylab.gca()
    box = ax.get_position()
    # Shrink the plot horizontally so the legend fits to the right of it.
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(recs, classes, loc="center left", bbox_to_anchor=(1.1, 0.5))
    pylab.xticks(pylab.arange(-4, 5))
    pylab.yticks(pylab.arange(-4, 5))
    pylab.xlabel("z1")
    pylab.ylabel("z2")
    pylab.savefig("{}/labeled_z_{}.png".format(visualization_dir, epoch))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号