def train(self, config):
    """Run the full training loop.

    Restores the latest checkpoint (if any), then iterates for
    ``self.epoch`` epochs over the batches yielded by
    ``self.reader.next_batch()``, running one optimizer step per batch.
    Progress is printed every 10 batches, summaries are written every
    2 batches, and a checkpoint is saved every 1000 batches.

    Args:
        config: training configuration object. Currently unused inside
            the loop (epoch count comes from ``self.epoch``); kept for
            interface compatibility with callers.
    """
    start_time = time.time()

    # NOTE(review): merge_all_summaries / SummaryWriter /
    # initialize_all_variables are the pre-1.0 TensorFlow API, so this
    # file targets TF 0.x; migrating is out of scope for this fix.
    merged_sum = tf.merge_all_summaries()
    writer = tf.train.SummaryWriter("./logs", self.sess.graph_def)
    tf.initialize_all_variables().run()

    # Resume from an existing checkpoint when one is available
    # (presumably load() is a no-op otherwise — confirm in its impl).
    self.load(self.checkpoint_dir)

    # Global batch counter used for summary/checkpoint numbering.
    # BUGFIX: `step` was previously referenced at the add_summary() and
    # save() calls below without ever being defined, which raised a
    # NameError on the very first batch (idx == 0 satisfies idx % 2 == 0).
    step = 0

    for epoch in range(self.epoch):
        epoch_loss = 0.
        for idx, x in enumerate(self.reader.next_batch()):
            _, loss, e_loss, g_loss, summary_str = self.sess.run(
                [self.optim, self.loss, self.e_loss, self.g_loss, merged_sum],
                feed_dict={self.x: x})
            epoch_loss += loss
            step += 1

            if idx % 10 == 0:
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f, e_loss: %.8f, g_loss: %.8f" \
                    % (epoch, idx, self.reader.batch_cnt, time.time() - start_time, loss, e_loss, g_loss))
            if idx % 2 == 0:
                writer.add_summary(summary_str, step)
            if idx != 0 and idx % 1000 == 0:
                self.save(self.checkpoint_dir, step)
# (removed web-scrape page artifacts that were pasted into this file:
#  "评论列表" = "comment list", "文章目录" = "article table of contents" —
#  bare text here is a Python SyntaxError)