cyclegan.py source code

Language: Python

Project: ml_gans | Author: imironhead
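```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, queue runners)

# FLAGS and the helpers build_image_batch_reader, build_cycle_gan,
# build_summaries and train_one_step are defined elsewhere in the project
# (not shown in this excerpt).


def train():
    """
    Build the CycleGAN graph, restore the latest checkpoint if one exists,
    and run training steps until train_one_step returns a falsy value.
    """
    ckpt_source_path = tf.train.latest_checkpoint(FLAGS.ckpt_dir_path)

    # queue-based batch readers for the two unpaired image domains X and Y
    xx_real = build_image_batch_reader(
        FLAGS.x_images_dir_path, FLAGS.batch_size)

    yy_real = build_image_batch_reader(
        FLAGS.y_images_dir_path, FLAGS.batch_size)

    # state that train_one_step keeps across steps
    # (e.g. a history of generated images for the discriminators)
    image_pool = {}

    model = build_cycle_gan(xx_real, yy_real, '')

    summaries = build_summaries(model)

    reporter = tf.summary.FileWriter(FLAGS.logs_dir_path)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        # local variables (e.g. epoch counters of input queues) need
        # their own initializer
        session.run(tf.local_variables_initializer())

        # resume from the latest checkpoint, if any
        if ckpt_source_path is not None:
            tf.train.Saver().restore(session, ckpt_source_path)

        # give up overlapped old data: read the (possibly restored) global
        # step and log SessionLog.START at it, so TensorBoard discards events
        # a previous run wrote at or after this step
        step = session.run(model['step'])

        reporter.add_session_log(
            tf.SessionLog(status=tf.SessionLog.START),
            global_step=step)

        # start the queue runner threads that feed the dataset readers
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=session, coord=coord)

        try:
            # train_one_step returns a falsy value when training should stop
            while train_one_step(model, summaries, image_pool, reporter):
                pass
        finally:
            # stop and join the reader threads even if a step raises
            coord.request_stop()
            coord.join(threads)
            reporter.close()
```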
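In CycleGAN, the discriminators are commonly updated with a history of previously generated images rather than only the latest batch, and the `image_pool` dict that `train()` hands to `train_one_step` appears to hold that kind of state. The sketch below is a hypothetical pure-Python image pool, not this project's code: the `ImagePool` class is an illustrative name, the default capacity of 50 follows the CycleGAN paper, and the 50% swap rule follows the authors' public reference implementation. Keying one pool per domain in a dict is likewise an assumption.

```python
import random
from typing import Generic, List, Optional, TypeVar

T = TypeVar("T")


class ImagePool(Generic[T]):
    """Hypothetical history buffer of generated images (not this project's code).

    Until the pool is full, every new image is stored and returned as-is.
    Afterwards, with probability 0.5 the new image replaces a random stored
    one and the stored (older) image is returned instead.
    """

    def __init__(self, capacity: int = 50, seed: Optional[int] = None) -> None:
        self.capacity = capacity
        self.images: List[T] = []
        self.rng = random.Random(seed)

    def query(self, image: T) -> T:
        # fill phase: keep the image and hand it straight back
        if len(self.images) < self.capacity:
            self.images.append(image)
            return image
        # full pool: half the time, swap the new image for an older one
        if self.rng.random() < 0.5:
            index = self.rng.randrange(self.capacity)
            old, self.images[index] = self.images[index], image
            return old
        return image


# hypothetical: one pool per generated domain, keyed like a dict of pools
pools = {"x": ImagePool(capacity=3, seed=0), "y": ImagePool(capacity=3, seed=0)}
returned = [pools["x"].query(i) for i in range(10)]
print(returned)
```

Integers stand in for image tensors here; the first three queries always come straight back, and later ones return either the new item or an older one from the pool.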
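The `SessionLog.START` event written right after the restore is what makes "give up overlapped old data" work: when TensorBoard (with orphaned-data purging enabled, its default) reads a START event at step s, it drops already-loaded events whose step is >= s, so summaries a previous run wrote after its last checkpoint do not overlap the resumed run. The following is a simplified pure-Python model of that rule for illustration only; the `(kind, step, value)` tuples are a hypothetical format, not TensorBoard's event protobuf.

```python
from typing import List, Tuple

# hypothetical event format: (kind, step, value); kind is "start" or "scalar"
Event = Tuple[str, int, float]


def load_with_purge(events: List[Event]) -> List[Tuple[int, float]]:
    """Replay events in file order, purging orphans on each START event."""
    kept: List[Tuple[int, float]] = []
    for kind, step, value in events:
        if kind == "start":
            # a restart at `step` invalidates everything logged at or after it
            kept = [(s, v) for s, v in kept if s < step]
        else:
            kept.append((step, value))
    return kept


# run 1 logged steps 0..4, but its last checkpoint was at step 2;
# run 2 restores step 2, logs START at 2, then logs steps 2..3 again
events = [("scalar", s, float(s)) for s in range(5)]
events += [("start", 2, 0.0)]
events += [("scalar", s, 10.0 + s) for s in (2, 3)]
print(load_with_purge(events))  # prints [(0, 0.0), (1, 1.0), (2, 12.0), (3, 13.0)]
```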
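`tf.train.Coordinator` and `tf.train.start_queue_runners` implement a familiar pattern: background threads keep an input queue full until a shared stop signal is raised, and the main thread requests the stop and joins them when training ends. Below is a rough standard-library analogue, not TensorFlow code: a `threading.Event` stands in for the coordinator, a bounded `queue.Queue` for the input queue, and the `reader` and `train` functions are hypothetical names.

```python
import queue
import threading
from typing import List


def reader(stop: threading.Event, batches: "queue.Queue[int]", worker: int) -> None:
    """Stand-in for a queue runner: keep pushing batches until told to stop."""
    n = 0
    while not stop.is_set():
        try:
            # time out so the stop flag is re-checked if the queue stays full
            batches.put(worker * 1000 + n, timeout=0.05)
            n += 1
        except queue.Full:
            continue


def train(num_steps: int, num_readers: int = 2) -> List[int]:
    stop = threading.Event()  # plays the role of tf.train.Coordinator
    batches: "queue.Queue[int]" = queue.Queue(maxsize=8)
    threads = [
        threading.Thread(target=reader, args=(stop, batches, w), daemon=True)
        for w in range(num_readers)
    ]
    for t in threads:
        t.start()
    consumed = []
    try:
        for _ in range(num_steps):  # plays the role of train_one_step
            consumed.append(batches.get(timeout=5))
    finally:
        stop.set()  # like coord.request_stop()
        for t in threads:  # like coord.join(threads)
            t.join()
    return consumed


print(len(train(num_steps=20)))  # prints 20
```

As in `train()`, the stop request and join sit in a `finally` clause so the reader threads are shut down even when a training step raises.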