word2vec.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:w2vec-similarity 作者: jayantj 项目源码 文件源码
def train(self):
    """Train the model until the epoch counter advances.

    Starts ``opts.concurrent_steps`` threads that each run
    ``self._train_thread_body``, then monitors progress from the calling
    thread: every ``opts.statistics_interval`` seconds it fetches the current
    epoch, global step, loss, word count and learning rate from the session
    and prints a one-line status. Summaries are written whenever more than
    ``opts.summary_interval`` seconds have passed since the last one, and a
    checkpoint is saved whenever more than ``opts.checkpoint_interval``
    seconds have passed since the last one.

    The monitoring loop stops as soon as the fetched epoch differs from the
    epoch observed on entry; the worker threads are then joined.

    Returns:
      The epoch value fetched on the final iteration of the monitoring loop.
    """
    opts = self._options
    initial_epoch, initial_words = self._session.run([self._epoch, self._words])
    summary_op = tf.merge_all_summaries()
    summary_writer = tf.train.SummaryWriter(opts.save_path,
                                            graph_def=self._session.graph_def)
    try:
      workers = []
      for _ in xrange(opts.concurrent_steps):
        t = threading.Thread(target=self._train_thread_body)
        t.start()
        workers.append(t)
      # The "last" timestamps start at 0, so the first pass through the loop
      # already exceeds both intervals and writes a summary and a checkpoint.
      last_words, last_time, last_summary_time = initial_words, time.time(), 0
      last_checkpoint_time = 0
      while True:
        time.sleep(opts.statistics_interval)  # Reports our progress once a while.
        (epoch, step, loss, words, lr) = self._session.run(
            [self._epoch, self.global_step, self._loss, self._words, self._lr])
        now = time.time()
        # Throughput since the previous report; the right-hand side is fully
        # evaluated before last_words/last_time are rebound.
        last_words, last_time, rate = words, now, (words - last_words) / (
            now - last_time)
        # Trailing "\r" with end="" keeps rewriting the same console line.
        print("Epoch %4d Step %8d: lr = %5.3f loss = %6.2f words/sec = %8.0f\r" %
              (epoch, step, lr, loss, rate), end="")
        sys.stdout.flush()
        if now - last_summary_time > opts.summary_interval:
          summary_str = self._session.run(summary_op)
          summary_writer.add_summary(summary_str, step)
          last_summary_time = now
        if now - last_checkpoint_time > opts.checkpoint_interval:
          self.saver.save(self._session,
                          opts.save_path + "model",
                          global_step=step.astype(int))
          last_checkpoint_time = now
        if epoch != initial_epoch:
          break
      for t in workers:
        t.join()
    finally:
      # A new writer is created on every call; close it so pending summaries
      # are flushed and the event file is released even if training raises.
      summary_writer.close()
    return epoch
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号