def interpolation2(src_img, att_img, inter_num, model_dir, model, gpu,
                   save_path='interpolation2.jpg'):
    '''Interpolate the source image's feature toward the attribute image's.

    Restores the latest checkpoint in ``model_dir`` and runs ``src_img``
    through ``model.Be`` to obtain ``model.B`` and ``model.e``. It runs
    ``att_img`` through ``model.Ax`` to obtain ``model.x``. Then, for
    ``i = 1 .. inter_num``, it decodes ``B`` with the linearly blended feature
    ``src_feat + (att_feat - src_feat) * (i / inter_num)`` via
    ``model.joiner('G_joiner', ...)``.

    Input
    src_img: the source image whose attribute you want to change; fed to
        ``model.Be``. Presumably a batch (only ``src_img[0]`` is placed in
        the output) -- TODO confirm against callers.
    att_img: the image that has the desired attribute; fed to ``model.Ax``.
    inter_num: number of interpolation points. The last point (lambda = 1)
        uses the attribute image's feature entirely. With 0, the output is
        just ``src_img[0]``.
    model_dir: the directory that contains the checkpoint, ckpt.* files.
    model: the GeneGAN network that defined in train.py.
    gpu: for example, '0,1'. Use '' for cpu mode. Written to the
        CUDA_VISIBLE_DEVICES environment variable of this process.
    save_path: file the result is written to; ``None`` skips saving.
        Defaults to 'interpolation2.jpg', the previously hard-coded name.

    Output
    out: ``src_img[0]``, inter1, ..., inter_{inter_num} (first element of
        each decoded batch) concatenated along axis 1 (side by side,
        assuming HWC images -- confirm). Also written to ``save_path``.

    Raises
    IOError: if no checkpoint is found in ``model_dir``. Previously the
        function silently produced output from randomly initialised weights.
    '''
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(model_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            # Without restored weights every frame would come from the random
            # initialisation above, which is never a meaningful result.
            raise IOError('No checkpoint found in %r' % (model_dir,))
        saver.restore(sess, ckpt.model_checkpoint_path)

        B, src_feat = sess.run([model.B, model.e], feed_dict={model.Be: src_img})
        att_feat = sess.run(model.x, feed_dict={model.Ax: att_img})

        # Collect the frames and concatenate once at the end; concatenating
        # inside the loop re-copies the whole growing strip every iteration.
        frames = [src_img[0]]
        for i in range(1, inter_num + 1):
            lambda_i = i / float(inter_num)  # float() keeps true division on Python 2
            blended = src_feat + (att_feat - src_feat) * lambda_i
            # NOTE(review): joiner is called with numpy values on every
            # iteration, so it presumably adds new ops/constants to the graph
            # each time; consider a placeholder-based joiner if inter_num is large.
            out_i = sess.run(model.joiner('G_joiner', B, blended))
            frames.append(out_i[0])
        out = np.concatenate(frames, axis=1)

    if save_path is not None:
        # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; newer
        # environments need imageio.imwrite (or similar) here.
        misc.imsave(save_path, out)
    return out
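# Web-page residue from the scraped source, kept as a comment so the file
# stays valid Python: "评论列表" ("comment list"), "文章目录" ("table of contents").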