test.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:GeneGAN 作者: Prinsphield 项目源码 文件源码
def interpolation(src_img, att_img, inter_num, model_dir, model, gpu):
    '''
    Build an attribute-interpolation strip, save it to 'interpolation.jpg'
    in the current working directory, and return it.

    Input
        src_img: the source image that you want to change its attribute
                 (only src_img[0] is placed in the strip; the whole value is
                 fed to model.Be)
        att_img: the attribute image that has certain attribute
                 (fed to model.Ax)
        inter_num: number of interpolation points; with inter_num <= 0 the
                   strip contains only src_img[0]
        model_dir: the directory that contains the checkpoint, ckpt.* files
        model: the GeneGAN network that defined in train.py; this function
               uses model.joiner, model.B, model.x, model.Ax and model.Be,
               and overwrites model.out_i on every step
        gpu: for example, '0,1'. Use '' for cpu mode
    Output
        out: [src_img, inter1, inter2, ..., inter_{inter_num}] concatenated
             along axis 1 (presumably the image width axis -- TODO confirm
             against the caller's array layout)
    '''
    import warnings

    # Restrict which devices TensorFlow may see; note this mutates the
    # process-wide environment.
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # Initialize first so every variable has a value even when no
        # checkpoint can be restored below.
        sess.run(tf.global_variables_initializer())

        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # Keep the best-effort behaviour (do not raise), but make it
            # visible that the output comes from unrestored variables.
            warnings.warn(
                'No checkpoint found in %r; running with freshly '
                'initialized variables.' % (model_dir,)
            )

        # The feed does not depend on the interpolation step, so build it once.
        feed = {model.Ax: att_img, model.Be: src_img}

        out = src_img[0]
        for i in range(1, inter_num + 1):
            # lambda_i walks from 1/inter_num up to 1.0 in equal steps.
            lambda_i = i / float(inter_num)
            # NOTE(review): this calls model.joiner once per step, which
            # presumably adds new ops to the graph each time -- confirm that
            # 'G_joiner' reuses the restored variables rather than creating
            # new ones.
            model.out_i = model.joiner('G_joiner', model.B, model.x * lambda_i)
            out_i = sess.run(model.out_i, feed_dict=feed)
            out = np.concatenate((out, out_i[0]), axis=1)

        misc.imsave('interpolation.jpg', out)

    # The docstring has always advertised `out` as the output; return it so
    # callers can use the strip without re-reading the file.
    return out
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号