def train(self):
    """Train until the epoch counter changes, reporting progress meanwhile.

    Starts ``opts.concurrent_steps`` threads that each run
    ``self._train_thread_body``, then polls the session every
    ``opts.statistics_interval`` seconds. Each poll prints a progress line
    and, when the respective interval has elapsed, writes a summary and
    saves a checkpoint. Polling ends on the first poll whose epoch value
    differs from the one read on entry; the threads are then joined.

    Returns:
        The epoch value fetched from the session on the final poll.
    """
    opts = self._options

    # Snapshot the counters so the loop can detect the epoch changing and
    # compute the first words/sec figure.
    initial_epoch, initial_words = self._session.run(
        [self._epoch, self._words])

    summary_op = tf.merge_all_summaries()
    summary_writer = tf.train.SummaryWriter(
        opts.save_path, graph_def=self._session.graph_def)

    threads = []
    for _ in xrange(opts.concurrent_steps):
        worker = threading.Thread(target=self._train_thread_body)
        worker.start()
        threads.append(worker)

    prev_words = initial_words
    prev_time = time.time()
    # Starting at 0 means the first poll will normally exceed both
    # intervals, so a summary and a checkpoint are written right away.
    prev_summary_time = 0
    prev_checkpoint_time = 0

    while True:
        # Pause between progress reports.
        time.sleep(opts.statistics_interval)

        fetches = [self._epoch, self.global_step, self._loss,
                   self._words, self._lr]
        epoch, step, loss, words, lr = self._session.run(fetches)
        now = time.time()

        # Throughput since the previous poll.
        rate = (words - prev_words) / (now - prev_time)
        prev_words = words
        prev_time = now

        # The trailing "\r" with end="" rewrites the same console line.
        print(
            "Epoch %4d Step %8d: lr = %5.3f loss = %6.2f words/sec = %8.0f\r"
            % (epoch, step, lr, loss, rate),
            end="")
        sys.stdout.flush()

        if now - prev_summary_time > opts.summary_interval:
            summary_str = self._session.run(summary_op)
            summary_writer.add_summary(summary_str, step)
            prev_summary_time = now

        if now - prev_checkpoint_time > opts.checkpoint_interval:
            # NOTE(review): the prefix is a plain string concatenation, so
            # it only lands inside save_path when save_path already ends
            # with a path separator.
            checkpoint_prefix = opts.save_path + "model"
            self.saver.save(self._session,
                            checkpoint_prefix,
                            global_step=step.astype(int))
            prev_checkpoint_time = now

        if epoch != initial_epoch:
            break

    for worker in threads:
        worker.join()

    return epoch
评论列表
文章目录