nasm.py source code

python

Project: variational-text-tensorflow  Author: carpedm20
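import time

import tensorflow as tf  # this code uses the legacy pre-1.0 TensorFlow API

def train(self, config):
    start_time = time.time()

    # Merge all registered summary ops and write event files to ./logs.
    merged_sum = tf.merge_all_summaries()
    writer = tf.train.SummaryWriter("./logs", self.sess.graph_def)

    # Initialize variables, then try to restore from the checkpoint directory.
    tf.initialize_all_variables().run()
    self.load(self.checkpoint_dir)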

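    # `step` was undefined in the original listing; a global batch counter is assumed here.
    step = 0
    for epoch in range(self.epoch):
      epoch_loss = 0.

      for idx, x in enumerate(self.reader.next_batch()):
        # Run one optimizer update and fetch the losses and merged summaries.
        _, loss, e_loss, g_loss, summary_str = self.sess.run(
            [self.optim, self.loss, self.e_loss, self.g_loss, merged_sum], feed_dict={self.x: x})

        epoch_loss += loss
        if idx % 10 == 0:
          print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f, e_loss: %.8f, g_loss: %.8f"
              % (epoch, idx, self.reader.batch_cnt, time.time() - start_time, loss, e_loss, g_loss))

        # Log summaries every 2 batches.
        if idx % 2 == 0:
          writer.add_summary(summary_str, step)

        # Save a checkpoint every 1000 batches, skipping the first batch of each epoch.
        if idx != 0 and idx % 1000 == 0:
          self.save(self.checkpoint_dir, step)

        step += 1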
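
The `step` value passed to `writer.add_summary` and `self.save` is most naturally read as a global batch counter that keeps increasing across epochs, while the `idx` conditions decide when to log and when to checkpoint. The following is a minimal sketch of that bookkeeping only, with no TensorFlow dependency. `run_schedule` and its parameters are hypothetical names introduced for illustration and are not part of the project.

```python
def run_schedule(num_epochs, batches_per_epoch, summary_every=2, save_every=1000):
    """Return the global steps at which summaries and checkpoints would be written.

    Hypothetical helper: it mirrors the idx-based conditions of the training
    loop and assumes `step` is a batch counter that is never reset.
    """
    summary_steps, save_steps = [], []
    step = 0  # global batch counter, shared by all epochs
    for epoch in range(num_epochs):
        for idx in range(batches_per_epoch):
            # Same condition as the summary write in the training loop.
            if idx % summary_every == 0:
                summary_steps.append(step)
            # Same condition as the checkpoint save: skip the first batch.
            if idx != 0 and idx % save_every == 0:
                save_steps.append(step)
            step += 1
    return summary_steps, save_steps


# Small numbers so the schedule is easy to inspect by hand.
summary_steps, save_steps = run_schedule(num_epochs=2, batches_per_epoch=5, save_every=2)
print(summary_steps)
print(save_steps)
```

Because the conditions test the per-epoch `idx` rather than `step`, the logging and checkpoint pattern restarts at the beginning of every epoch, while the recorded step values stay unique and increasing.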