def run_epoch(self, session, train_op, train_writer, batch_gen=None, num_iterations=NUM_ITERATIONS, output_dir="output", write_image=False):
    """Run ``num_iterations`` steps and return the results of the last step.

    Each step pulls one batch from ``batch_gen`` and feeds it to ``self.image``,
    starting from ``self.add_noise_to_feed({})`` when the model is
    ``MULTISCALE``. When ``self.is_training`` is true, the step also runs
    ``train_op``. It then optionally writes the generator output to disk and
    logs summaries to ``train_writer``. Otherwise only ``self.generator.out``
    is evaluated.

    Args:
        session: Object whose ``run(ops, feed_dict=...)`` executes the graph.
        train_op: Op run on each training step; its result is discarded.
        train_writer: Writer exposing ``add_summary``/``flush``, or None to
            skip summary logging.
        batch_gen: Object whose ``get_batch()`` returns the value fed to
            ``self.image``. Required despite the ``None`` default.
        num_iterations: Number of steps to run.
        output_dir: Directory under which ``images/valid_<step>.png`` files
            are written.
        write_image: When true (training mode only), write the generator
            output roughly five times per epoch.

    Returns:
        Tuple ``(loss, summary, image_summary, last_out, global_step)`` from
        the final step. In non-training mode every element except ``last_out``
        is None. If no steps run, all five elements are None.

    Raises:
        ValueError: If ``batch_gen`` is None.
    """
    if batch_gen is None:
        raise ValueError("run_epoch requires a batch generator (batch_gen is None)")

    epoch_size = num_iterations
    # Integer division: under Python 3 `/` yields a float (e.g. 7 / 5 == 1.4),
    # and `step % 1.4 == 0` almost never holds, which silently skipped most
    # image/summary writes. max(1, ...) keeps the interval >= 1 for short epochs.
    image_skip = max(1, epoch_size // 5)
    summary_skip = max(1, epoch_size // 25)

    # Pre-bind the results so an empty epoch (num_iterations <= 0) returns
    # Nones instead of raising UnboundLocalError at the return statement.
    loss = summary = image_summary = last_out = global_step = None

    for step in range(epoch_size):
        feed = self.add_noise_to_feed({}) if self.model_name == MULTISCALE else {}
        feed[self.image] = batch_gen.get_batch()

        if not self.is_training:
            last_out = session.run(self.generator.out, feed_dict=feed)
            loss = summary = image_summary = global_step = None
            continue

        ops = [train_op, self.loss, self.merged, self.image_summary,
               self.input_summary, self.generator.out, self.global_step]
        (_, loss, summary, image_summary, input_summary,
         last_out, global_step) = session.run(ops, feed_dict=feed)

        # NOTE(review): images are only written in training mode even though the
        # file name says "valid_" -- confirm this is intended.
        if write_image and step % image_skip == 0:
            image_path = os.path.join(output_dir, "images", "valid_%d.png" % step)
            utils.write_image(image_path, last_out)

        if train_writer is not None:
            wrote = False
            if step % summary_skip == 0:
                train_writer.add_summary(summary, global_step)
                wrote = True
            if step % image_skip == 0:
                # Tag image summaries with global_step as well. Without it they
                # were all recorded at the writer's default step instead of
                # alongside the scalar summaries.
                train_writer.add_summary(input_summary, global_step)
                train_writer.add_summary(image_summary, global_step)
                wrote = True
            if wrote:
                # One flush per step instead of one per add_summary call.
                train_writer.flush()

    return loss, summary, image_summary, last_out, global_step
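# Comment list (stray web-page navigation text, not part of the code)
# Table of contents (stray web-page navigation text, not part of the code)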