def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
if self.model.validation_data and self.histogram_freq:
if epoch % self.histogram_freq == 0:
# TODO: implement batched calls to sess.run
# (current call will likely go OOM on GPU)
if self.model.uses_learning_phase:
cut_v_data = len(self.model.inputs)
val_data = self.model.validation_data[:cut_v_data] + [0]
tensors = self.model.inputs + [K.learning_phase()]
else:
val_data = self.model.validation_data
tensors = self.model.inputs
feed_dict = dict(zip(tensors, val_data))
result = self.sess.run([self.merged], feed_dict=feed_dict)
summary_str = result[0]
self.writer.add_summary(summary_str, epoch)
for name, value in logs.items():
if name in ['batch', 'size']:
continue
summary = tf.Summary()
summary_value = summary.value.add()
summary_value.simple_value = value.item()
summary_value.tag = name
self.writer.add_summary(summary, epoch)
self.writer.flush()
评论列表
文章目录