style_transfer.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:ml_styles 作者: imironhead 项目源码 文件源码
def transfer():
    """Stylize one content image with a trained network and write it as JPEG.

    Reads its configuration from the module-level ``FLAGS``:

    * ``FLAGS.ckpt_path``    - a checkpoint file, or a directory from which
      the latest checkpoint is picked.
    * ``FLAGS.content_path`` - path of the content image; must exist and must
      not be a directory.
    * ``FLAGS.padding``      - number of pixels cropped from every spatial
      border of the network output before encoding.
    * ``FLAGS.styled_path``  - destination path of the encoded JPEG.

    Raises:
        AssertionError: if ``FLAGS.ckpt_path`` does not exist, or if
            ``FLAGS.content_path`` is missing or is a directory.
    """
    # NOTE: the checks below raise AssertionError explicitly rather than using
    # `assert`, so they still run under `python -O` while callers keep seeing
    # the same exception type and messages.
    if tf.gfile.IsDirectory(FLAGS.ckpt_path):
        # NOTE(review): latest_checkpoint is documented to return None when
        # the directory holds no checkpoint; Saver.restore below would then
        # reject it - confirm whether an earlier, clearer error is wanted.
        ckpt_source_path = tf.train.latest_checkpoint(FLAGS.ckpt_path)
    elif tf.gfile.Exists(FLAGS.ckpt_path):
        ckpt_source_path = FLAGS.ckpt_path
    else:
        raise AssertionError('bad checkpoint')

    if not tf.gfile.Exists(FLAGS.content_path):
        raise AssertionError('bad content_path')
    if tf.gfile.IsDirectory(FLAGS.content_path):
        raise AssertionError('bad content_path')

    image_contents = build_contents_reader()

    network = build_style_transfer_network(image_contents, training=False)

    # Crop FLAGS.padding pixels from each side of both spatial axes. After the
    # batch axis is squeezed away, axis 1 becomes the first axis of the image
    # handed to encode_jpeg (height) and axis 2 the second (width).
    shape = tf.shape(network['image_styled'])

    cropped_height = shape[1] - 2 * FLAGS.padding
    cropped_width = shape[2] - 2 * FLAGS.padding

    image_styled = tf.slice(
        network['image_styled'],
        [0, FLAGS.padding, FLAGS.padding, 0],
        [-1, cropped_height, cropped_width, -1])

    # Drop the batch axis (requires a batch of exactly one image).
    image_styled = tf.squeeze(image_styled, [0])
    # Map values to the 0..255 range; presumably the network emits values in
    # [-1, 1] - TODO confirm against build_style_transfer_network.
    image_styled = image_styled * 127.5 + 127.5
    # Reverse the channel axis (presumably BGR <-> RGB - TODO confirm against
    # build_contents_reader).
    image_styled = tf.reverse(image_styled, [2])
    image_styled = tf.saturate_cast(image_styled, tf.uint8)
    image_styled = tf.image.encode_jpeg(image_styled)

    image_styled_writer = tf.write_file(FLAGS.styled_path, image_styled)

    with tf.Session() as session:
        tf.train.Saver().restore(session, ckpt_source_path)

        # make dataset reader work
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        try:
            session.run(image_styled_writer)
        finally:
            # Always stop and join the queue-runner threads, even when the
            # run fails, so they are not left running behind the error.
            coord.request_stop()
            coord.join(threads)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号