train_wgan.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:chainer-image-generation 作者: fukuta0614 项目源码 文件源码
def visualize(gen, epoch, savedir, batch_size=36, image_type='sigmoid'):
    z = chainer.Variable(gen.xp.asarray(gen.make_hidden(batch_size)), volatile=True)
    x_fake = gen(z, train=False)
    if image_type == 'sigmoid':
        img_gen = ((cuda.to_cpu(x_fake.data)) * 255).clip(0, 255).astype(np.uint8)
    else:
        img_gen = ((cuda.to_cpu(x_fake.data) + 1) * 127.5).clip(0, 255).astype(np.uint8)

    fig = plt.figure(figsize=(9, 9))
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
    for i in range(36):
        ax = fig.add_subplot(6, 6, i + 1, xticks=[], yticks=[])
        ax.imshow(img_gen[i].transpose(1, 2, 0))
    fig.savefig('{}/generate_{:03d}'.format(savedir, epoch))
    # plt.show()
    plt.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号