train.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:Bayesian-FlowNet 作者: Johswald 项目源码 文件源码
def main(_):
    """Build the FlowNet training graph and run the TF-Slim training loop.

    The single positional argument is ignored (presumably the argv list
    handed over by the application runner — confirm against the caller).
    Every setting is read from the module-level ``FLAGS``.
    """
    graph = tf.Graph()
    with graph.as_default():
        # Input pipeline: two image batches plus the ground-truth flow.
        # NOTE(review): the meaning of the literal ``True`` passed to
        # get_data is not visible here (possibly a train/eval switch) — confirm.
        first_imgs, second_imgs, gt_flows = flownet_tools.get_data(
            FLAGS.datadir, True)

        # Data augmentation takes and returns all three tensors.
        first_imgs, second_imgs, gt_flows = apply_augmentation(
            first_imgs, second_imgs, gt_flows)

        # Forward pass, then an image summary of the predicted flow.
        predicted_flows = model(first_imgs, second_imgs, gt_flows)
        flownet.image_summary(None, None, "E_result", predicted_flows)

        step = slim.get_or_create_global_step()
        train_op = flownet.create_train_op(step)

        session_config = tf.ConfigProto()
        # Grow GPU memory on demand rather than reserving it all at start-up.
        session_config.gpu_options.allow_growth = True

        checkpoint_saver = tf_saver.Saver(
            max_to_keep=FLAGS.max_checkpoints,
            keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
        )

        # Built only after every summary op exists, so merge_all() sees them all.
        train_kwargs = dict(
            logdir=FLAGS.logdir + '/train',
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            summary_op=tf.summary.merge_all(),
            log_every_n_steps=FLAGS.log_every_n_steps,
            trace_every_n_steps=FLAGS.trace_every_n_steps,
            session_config=session_config,
            saver=checkpoint_saver,
            number_of_steps=FLAGS.max_steps,
        )
        slim.learning.train(train_op, **train_kwargs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号