cyclegan.py source code

python

Project: ml_gans  Author: imironhead
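# Excerpt from cyclegan.py. FLAGS, prepare_paths() and build_cycle_gan() are
# defined elsewhere in the same module.
import imageio
import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph API (tf.placeholder, tf.Session)


def translate():
    """
    Translate every source image returned by prepare_paths() with the trained
    CycleGAN generator and save each input next to its translation.
    """
    # List of (source image path, output image path) pairs.
    image_path_pairs = prepare_paths()

    # Batch of 256x256 RGB images as raw uint8 pixels.
    reals = tf.placeholder(shape=[None, 256, 256, 3], dtype=tf.uint8)

    # Rescale pixels from [0, 255] to [-1, 1].
    flow = tf.cast(reals, dtype=tf.float32) / 127.5 - 1.0

    # The same tensor is passed for both inputs; only model['fake'] is used below.
    model = build_cycle_gan(flow, flow, FLAGS.mode)

    # Map generator output from [-1, 1] back to [0, 255]. saturate_cast clamps
    # out-of-range values instead of letting them wrap around.
    fakes = tf.saturate_cast(model['fake'] * 127.5 + 127.5, tf.uint8)

    # Path to the most recent checkpoint in FLAGS.ckpt_dir_path.
    ckpt_source_path = tf.train.latest_checkpoint(FLAGS.ckpt_dir_path)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(tf.local_variables_initializer())

        # Overwrite the freshly initialized weights with the trained ones.
        tf.train.Saver().restore(session, ckpt_source_path)

        for i in range(0, len(image_path_pairs), FLAGS.batch_size):
            path_pairs = image_path_pairs[i:i + FLAGS.batch_size]

            # scipy.misc.imread/imsave have been removed from recent SciPy
            # releases; imageio provides imread/imwrite instead.
            real_images = [imageio.imread(p[0]) for p in path_pairs]

            fake_images = session.run(fakes, feed_dict={reals: real_images})

            for idx, path in enumerate(path_pairs):
                # Input on the left, translation on the right (axis 1 = width).
                image = np.concatenate(
                    [real_images[idx], fake_images[idx]], axis=1)

                imageio.imwrite(path[1], image)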
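The pixel handling and batching in translate() can be checked in plain NumPy, without TensorFlow or a checkpoint. This is a minimal sketch, not code from ml_gans: the helper names to_model_range, to_uint8, batches and side_by_side are hypothetical, np.clip followed by astype(np.uint8) is assumed to mirror tf.saturate_cast (clamp, then truncating cast), and negating the input is a hypothetical stand-in for the trained generator.

```python
import numpy as np


def to_model_range(images_uint8):
    """Map uint8 pixels in [0, 255] to float32 in [-1, 1] (mirrors `flow`)."""
    return images_uint8.astype(np.float32) / 127.5 - 1.0


def to_uint8(images_float):
    """Map values in [-1, 1] back to uint8, clamping like tf.saturate_cast."""
    scaled = images_float * 127.5 + 127.5
    # Clamp first so out-of-range values saturate instead of wrapping around;
    # astype then truncates toward zero.
    return np.clip(scaled, 0, 255).astype(np.uint8)


def batches(items, batch_size):
    """Yield consecutive slices of at most batch_size items, like the loop in translate()."""
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]


def side_by_side(real, fake):
    """Concatenate two HxWxC images along the width axis (axis=1)."""
    return np.concatenate([real, fake], axis=1)


if __name__ == "__main__":
    # Five random 256x256 RGB "real" images.
    reals = np.random.randint(0, 256, size=(5, 256, 256, 3), dtype=np.uint8)
    for batch in batches(list(reals), 2):
        # Hypothetical generator: negating values in [-1, 1] inverts the colours.
        fakes = to_uint8(-to_model_range(np.stack(batch)))
        pairs = [side_by_side(r, f) for r, f in zip(batch, fakes)]
        print(pairs[0].shape)  # (256, 512, 3)
```

Because the final cast truncates rather than rounds, a uint8 → [-1, 1] → uint8 round trip can come back one level lower for some pixels; adding 0.5 before clipping would round to the nearest level instead.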