generate.py 文件源码

python
阅读 64 收藏 0 点赞 0 评论 0

项目:IllustrationGAN 作者: tdrussell 项目源码 文件源码
def main(argv=None):
    """Generate a GRID-shaped montage of generator samples and save it.

    Builds the generator graph from ``model.generator_template()``, feeds it
    one batch of standard-normal noise vectors, and draws each generated
    image into one cell of a matplotlib ``ImageGrid``. The figure is written
    to ``montage.png`` and then shown interactively.

    Args:
        argv: Unused. Presumably accepted so this can be used as a
            ``tf.app.run``-style entry point (confirm against the caller).

    Side effects:
        Overwrites ``FLAGS.batch_size`` with ``GRID[0] * GRID[1]`` and writes
        ``montage.png`` to the current working directory.
    """
    input.init_dataset_constants()
    num_images = GRID[0] * GRID[1]
    # The placeholder below is sized from FLAGS.batch_size, so the batch must
    # hold exactly one sample per grid cell.
    FLAGS.batch_size = num_images

    with tf.Graph().as_default():
        g_template = model.generator_template()
        z = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.z_size])
        # Uncomment to draw the same latent vectors on every run:
        # np.random.seed(1337)
        noise = np.random.normal(size=(FLAGS.batch_size, FLAGS.z_size))
        with pt.defaults_scope(phase=pt.Phase.test):
            gen_images_op, _ = pt.construct_all(g_template, input=z)

        # The context manager releases the session even if variable
        # initialization/restoration or the forward pass raises.
        with tf.Session() as sess:
            init_variables(sess)
            gen_images = sess.run(gen_images_op, feed_dict={z: noise})

    # Rescale what is presumably the generator's [-1, 1] output range
    # (TODO confirm against model.generator_template) into the [0, 1] range
    # imshow expects for float RGB data. Clip to absorb any overshoot.
    gen_images = np.clip((gen_images + 1) / 2, 0.0, 1.0)

    fig = plt.figure(1)
    grid = ImageGrid(fig, 111, nrows_ncols=GRID, axes_pad=0.1)
    # gen_images holds batch_size == num_images samples, one per grid cell.
    for i, im in enumerate(gen_images):
        axis = grid[i]
        axis.axis('off')
        axis.imshow(im)

    # Save before showing. With many backends plt.show() blocks until the
    # window is closed, and the figure may be torn down afterwards, which
    # would leave savefig writing an empty image.
    fig.savefig('montage.png', dpi=100, bbox_inches='tight')
    plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号