def main(unused_argv):
    """Build the model graph and run the training loop.

    Ensures ``FLAGS.train_dir`` exists and builds the model returned by
    ``create_model(FLAGS)`` inside a fresh ``tf.Graph``. It then attaches a
    training op via ``tf.contrib.layers.optimize_loss`` and hands control to
    ``tf.contrib.slim.learning.train``. That call writes checkpoints (through
    the Saver created here) and summaries into ``FLAGS.train_dir``.

    Args:
        unused_argv: Positional arguments supplied by the program entry point
            (presumably ``tf.app.run``); not used.
    """
    train_dir = FLAGS.train_dir
    # Use tf.gfile instead of os.path/os.makedirs so TensorFlow filesystem
    # URLs (e.g. gs://...) are handled the same way the checkpoint/summary
    # writers below handle them. IsDirectory (unlike os.path.exists) also
    # does not mistake a regular file at this path for the directory.
    if not tf.gfile.IsDirectory(train_dir):
        tf.logging.info("Creating training directory: %s", train_dir)
        tf.gfile.MakeDirs(train_dir)

    g = tf.Graph()
    with g.as_default():
        # Build the model's ops; model.total_loss, model.global_step and
        # model.config are read below.
        model = create_model(FLAGS)
        model.build()

        # learning_rate_fn is defined elsewhere in this file; its two results
        # are passed straight through to optimize_loss.
        learning_rate, learning_rate_decay_fn = learning_rate_fn(
            model.config.batch_size, FLAGS.num_epochs_per_decay)

        train_op = tf.contrib.layers.optimize_loss(
            loss=model.total_loss,
            global_step=model.global_step,
            learning_rate=learning_rate,
            learning_rate_decay_fn=learning_rate_decay_fn,
            optimizer=FLAGS.optimizer)

        # Created inside the graph context so it covers this graph's
        # variables; max_to_keep bounds how many checkpoints are retained.
        saver = tf.train.Saver(max_to_keep=FLAGS.max_checkpoints_to_keep)

    # graph=g is passed explicitly, so this call does not rely on g being
    # the default graph.
    tf.contrib.slim.learning.train(
        train_op,
        train_dir,
        log_every_n_steps=FLAGS.log_every_n_steps,
        graph=g,
        global_step=model.global_step,
        number_of_steps=FLAGS.number_of_steps,
        save_interval_secs=FLAGS.save_interval_secs,
        save_summaries_secs=FLAGS.save_summaries_secs,
        saver=saver)