def main(_):
    """Build the FlowNet training graph and hand it to TF-Slim's training loop.

    The single positional argument is ignored (presumably the leftover argv
    supplied by the app runner -- TODO confirm). All settings are read from
    ``FLAGS``.
    """
    with tf.Graph().as_default():
        # Input pipeline: two image batches plus ground-truth flow. The second
        # argument to get_data is True here -- presumably selecting the
        # training split; TODO confirm against flownet_tools.get_data.
        first_imgs, second_imgs, gt_flows = flownet_tools.get_data(
            FLAGS.datadir, True)
        first_imgs, second_imgs, gt_flows = apply_augmentation(
            first_imgs, second_imgs, gt_flows)

        # Forward pass. Ground-truth flows are passed in as well -- presumably
        # so the model can attach its loss; TODO confirm.
        predicted_flows = model(first_imgs, second_imgs, gt_flows)
        flownet.image_summary(None, None, "E_result", predicted_flows)

        # Optimisation op, driven by the (possibly newly created) global step.
        step = slim.get_or_create_global_step()
        train_op = flownet.create_train_op(step)

        # Let the GPU allocator grow on demand rather than reserving all
        # device memory at session start.
        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = True

        checkpoint_saver = tf_saver.Saver(
            max_to_keep=FLAGS.max_checkpoints,
            keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
        )

        # Built after the Saver so merge_all() runs at the same point in graph
        # construction as before.
        train_kwargs = {
            "logdir": FLAGS.logdir + '/train',
            "save_summaries_secs": FLAGS.save_summaries_secs,
            "save_interval_secs": FLAGS.save_interval_secs,
            "summary_op": tf.summary.merge_all(),
            "log_every_n_steps": FLAGS.log_every_n_steps,
            "trace_every_n_steps": FLAGS.trace_every_n_steps,
            "session_config": session_config,
            "saver": checkpoint_saver,
            "number_of_steps": FLAGS.max_steps,
        }
        slim.learning.train(train_op, **train_kwargs)
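# 评论列表 ("Comment list") -- web-page navigation residue, not part of the program
# 文章目录 ("Table of contents") -- web-page navigation residue, not part of the program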