def interpolation(src_img, att_img, inter_num, model_dir, model, gpu):
    '''
    Build an attribute-interpolation strip, save it and return it.

    Weights are restored from the latest checkpoint found in `model_dir`.
    Then, for i = 1..inter_num, `model.joiner` is evaluated with `model.x`
    scaled by lambda_i = i / inter_num.  The first source image and the
    first image of every result are concatenated along axis 1 and written
    to 'interpolation.jpg' in the current working directory.

    Input
        src_img: the source image that you want to change its attribute
                 (fed to model.Be; src_img[0] is the first tile of the strip)
        att_img: the attribute image that has certain attribute
                 (fed to model.Ax)
        inter_num: number of interpolation points
        model_dir: the directory that contains the checkpoint, ckpt.* files
        model: the GeneGAN network that defined in train.py
        gpu: for example, '0,1'. Use '' for cpu mode
    Output
        out: [src_img, inter1, inter2, ..., inter_{inter_num}] joined along
             axis 1 (the same array that is written to disk)
    '''
    import warnings

    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # Keep going (as before) but make the situation visible: without
            # a restore the variables only hold their initializer values.
            warnings.warn(
                'No checkpoint found in %r; running with freshly '
                'initialized weights.' % (model_dir,)
            )

        # Collect the tiles and join them once at the end instead of
        # re-copying the growing strip on every iteration.
        frames = [src_img[0]]
        for i in range(1, inter_num + 1):
            lambda_i = i / float(inter_num)
            # NOTE(review): this builds a new joiner output on every
            # iteration and stores it on the model -- presumably adding ops
            # to the graph each time; confirm against model.joiner.
            model.out_i = model.joiner('G_joiner', model.B, model.x * lambda_i)
            out_i = sess.run(
                model.out_i,
                feed_dict={model.Ax: att_img, model.Be: src_img},
            )
            frames.append(out_i[0])

    out = np.concatenate(frames, axis=1)
    misc.imsave('interpolation.jpg', out)
    return out
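# (Web-page residue, not part of the program: "Comment list")
# (Web-page residue, not part of the program: "Article table of contents")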