train.py source code

python

Project: num-seq-recognizer  Author: gmlove
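import os

import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.x

# FLAGS, create_model and learning_rate_fn are defined elsewhere in the project.


def main(unused_argv):
  # Make sure the checkpoint/summary directory exists before training starts.
  train_dir = FLAGS.train_dir
  if not os.path.exists(train_dir):
    tf.logging.info("Creating training directory: %s", train_dir)
    os.makedirs(train_dir)

  # Build the model inside its own graph.
  g = tf.Graph()
  with g.as_default():
    model = create_model(FLAGS)
    model.build()

    # Initial learning rate plus a decay function of (learning_rate, global_step).
    learning_rate, learning_rate_decay_fn = learning_rate_fn(
      model.config.batch_size, FLAGS.num_epochs_per_decay)

    # Training op that minimizes the total loss with the chosen optimizer.
    train_op = tf.contrib.layers.optimize_loss(
      loss=model.total_loss,
      global_step=model.global_step,
      learning_rate=learning_rate,
      learning_rate_decay_fn=learning_rate_decay_fn,
      optimizer=FLAGS.optimizer)

    # Keep only the most recent checkpoints on disk.
    saver = tf.train.Saver(max_to_keep=FLAGS.max_checkpoints_to_keep)

  # Run the training loop; checkpoints and summaries are written to train_dir.
  tf.contrib.slim.learning.train(
    train_op,
    train_dir,
    log_every_n_steps=FLAGS.log_every_n_steps,
    graph=g,
    global_step=model.global_step,
    number_of_steps=FLAGS.number_of_steps,
    save_interval_secs=FLAGS.save_interval_secs,
    save_summaries_secs=FLAGS.save_summaries_secs,
    saver=saver)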
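
The `learning_rate_fn` helper called in `main` is not shown in this snippet. `tf.contrib.layers.optimize_loss` expects `learning_rate_decay_fn` to take the learning rate and the global step and return the decayed rate, so the helper presumably returns an initial rate together with such a function. The following is a minimal pure-Python sketch of one possible shape, a staircase exponential schedule. The name `make_learning_rate_fn` and the values of `num_examples_per_epoch`, `initial_learning_rate` and `decay_factor` are hypothetical and not taken from the project; the real helper would operate on tensors rather than plain numbers.

```python
def make_learning_rate_fn(batch_size, num_epochs_per_decay,
                          num_examples_per_epoch=6400,   # hypothetical dataset size
                          initial_learning_rate=0.1,     # hypothetical starting rate
                          decay_factor=0.5):             # hypothetical decay multiplier
    """Return (initial_learning_rate, decay_fn), the pair unpacked in main()."""
    num_batches_per_epoch = num_examples_per_epoch / batch_size
    # Number of training steps between two successive decays (at least 1).
    decay_steps = max(1, int(num_batches_per_epoch * num_epochs_per_decay))

    def decay_fn(learning_rate, global_step):
        # Staircase exponential decay: multiply by decay_factor once
        # for every full block of decay_steps steps completed so far.
        return learning_rate * decay_factor ** (global_step // decay_steps)

    return initial_learning_rate, decay_fn


# Usage: 6400 examples / batch of 32 = 200 steps per epoch,
# so decaying every 2 epochs means every 400 steps.
lr, decay_fn = make_learning_rate_fn(batch_size=32, num_epochs_per_decay=2)
rates = [decay_fn(lr, step) for step in (0, 399, 400, 800)]
```

In a TensorFlow 1.x version of this sketch, the body of `decay_fn` could be replaced by a call to `tf.train.exponential_decay` with `staircase=True`, which computes the same schedule on tensors.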