trainer.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:a3c-tensorflow 作者: carpedm20 项目源码 文件源码
def train(self, num_global_steps=100000000):
    """Run this worker's training loop under a ``tf.train.Supervisor``.

    Builds a saver over every global variable whose name does not start
    with ``"local"``, opens a managed session against ``self.server.target``
    and repeatedly calls ``self.agent.process(sess)`` until the supervisor
    asks to stop or the global step reaches ``num_global_steps``.

    Args:
        num_global_steps: Upper bound on the global step. A falsy value
            (``0`` or ``None``) disables the bound, so the loop only ends
            when the supervisor requests a stop.
    """
    # Only non-"local" variables are checkpointed and used for readiness.
    variables_to_save = [
        v for v in tf.global_variables() if not v.name.startswith("local")
    ]
    init_op = tf.variables_initializer(variables_to_save)
    init_all_op = tf.global_variables_initializer()

    saver = FastSaver(variables_to_save)

    var_list = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
    tf.logging.info('Trainable vars:')
    slim.model_analyzer.analyze_vars(var_list, print_info=True)

    def init_fn(ses):
        # Passed to the Supervisor as init_fn; runs the initializer for
        # all global variables, not just the saved subset.
        tf.logging.info("=" * 30)
        tf.logging.info("Initializing all parameters.")
        tf.logging.info("=" * 30)
        ses.run(init_all_op)

    sess_config = tf.ConfigProto(
        device_filters=[
            "/job:ps",
            "/job:worker/task:{}/cpu:0".format(self.task),
        ])

    summary_writer = tf.summary.FileWriter(
        "{}_{}".format(self.log_dir, self.task))
    tf.logging.info("Events directory: %s_%s", self.log_dir, self.task)
    sv = tf.train.Supervisor(
        is_chief=(self.task == 0),
        logdir=self.log_dir,
        saver=saver,
        summary_op=None,
        init_op=init_op,
        init_fn=init_fn,
        summary_writer=summary_writer,
        ready_op=tf.report_uninitialized_variables(variables_to_save),
        save_model_secs=600,
        save_summaries_secs=30)

    # Bound before the session opens so the final log line cannot raise
    # NameError if the session block exits before the first step is read.
    global_step = 0

    with sv.managed_session(self.server.target, config=sess_config) as sess, \
            sess.as_default():
        sess.run(self.agent.sync)
        self.agent.start(sess, summary_writer)

        global_step = sess.run(self.agent.global_step)
        tf.logging.info("Starting training at step=%d", global_step)

        while not sv.should_stop() and (
                not num_global_steps or global_step < num_global_steps):
            self.agent.process(sess)
            global_step = sess.run(self.agent.global_step)

    # Ask for all the services to stop.
    sv.stop()
    tf.logging.info('reached %s steps. worker stopped.', global_step)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号