python类save_params()的实例源码

mnist_pixelvae_train.py 文件源码 项目:PixelVAE 作者: igul222 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def generate_and_save_samples(tag):
    """Save the current parameters, then generate and save a grid of samples.

    Steps, in order:
      1. Call ``lib.save_params`` with ``OUT_DIR/<tag>_params.pkl``.
      2. Draw 100 latent vectors of size ``LATENT_DIM`` from a standard
         normal distribution.
      3. Fill a (100, N_CHANNELS, HEIGHT, WIDTH) buffer one pixel/channel
         at a time, calling ``sample_fn`` once per step.
      4. Print the generation time and write ``samples_<tag>.jpg`` into
         ``OUT_DIR``.

    Args:
        tag: String embedded in both output file names.
    """
    lib.save_params(os.path.join(OUT_DIR, tag + "_params.pkl"))

    def save_images(images, filename, i=None):
        """Tile 100 images into one 10x10 grid and write it as a JPEG.

        Args:
            images: Array reshapeable to (10, 10, 28, 28); the caller passes
                an array laid out as (batch, n channels, height, width).
            filename: Prefix of the output file name.
            i: Optional suffix; when given, the file is tagged
                ``<tag>_<i>`` instead of ``<tag>``.
        """
        if i is not None:
            new_tag = "{}_{}".format(tag, i)
        else:
            new_tag = tag

        # (a, b, h, w) -> (b, h, a, w) -> (10*28, 10*28): image number
        # a*10 + b ends up in grid row b, grid column a.
        images = images.reshape((10, 10, 28, 28))
        images = images.transpose(1, 2, 0, 3)
        images = images.reshape((10 * 28, 10 * 28))

        # NOTE(review): scipy.misc.toimage was removed in SciPy 1.2; this
        # only runs against older SciPy releases.
        image = scipy.misc.toimage(images, cmin=0.0, cmax=1.0)
        # Build the path with os.path.join, like the params path above, so
        # both outputs resolve to the same directory for any OUT_DIR value.
        image_name = '{}_{}.jpg'.format(filename, new_tag)
        image.save(os.path.join(OUT_DIR, image_name))

    latents = np.random.normal(size=(100, LATENT_DIM))
    latents = latents.astype(theano.config.floatX)

    samples = np.zeros(
        (100, N_CHANNELS, HEIGHT, WIDTH),
        dtype=theano.config.floatX
    )
    next_sample = samples.copy()

    t0 = time.time()
    for row in xrange(HEIGHT):
        for col in xrange(WIDTH):
            for channel in xrange(N_CHANNELS):
                # One call per pixel/channel: the input fed back to
                # sample_fn gets the binarized value, while `samples`
                # records the raw output at the same position.
                samples_p_value = sample_fn(latents, next_sample)
                binarized = binarize(samples_p_value)
                next_sample[:, channel, row, col] = binarized[:, channel, row, col]
                samples[:, channel, row, col] = samples_p_value[:, channel, row, col]

    t1 = time.time()
    print("Time taken for generation {:.4f}".format(t1 - t0))

    # NOTE(review): this saves the output of the final sample_fn call, not
    # the `samples` buffer accumulated above (which is never read). Confirm
    # which of the two was intended before changing it.
    save_images(samples_p_value, 'samples')


问题


面经


文章

微信
公众号

扫码关注公众号