def _evaluate_on_validation(self, get_val_instance_generator,
                            batch_size,
                            num_val_steps,
                            session):
    """Run the model over the validation set and return the mean
    accuracy, the mean loss, and a tf.Summary for TensorBoard."""
    val_batch_gen = DataManager.get_batch_generator(
        get_val_instance_generator, batch_size)
    # Calculate the mean of the validation metrics
    # over the validation set.
    val_accuracies = []
    val_losses = []
    for val_batch in tqdm(val_batch_gen,
                          total=num_val_steps,
                          desc="Validation Batches Completed",
                          leave=False):
        feed_dict = self._get_validation_feed_dict(val_batch)
        val_batch_acc, val_batch_loss = session.run(
            [self.accuracy, self.loss],
            feed_dict=feed_dict)
        val_accuracies.append(val_batch_acc)
        val_losses.append(val_batch_loss)
    # Take the mean of the accuracies and losses.
    # TODO/FIXME: this assumes each batch is the same size, which
    # is not necessarily true (the last batch may be smaller).
    mean_val_accuracy = np.mean(val_accuracies)
    mean_val_loss = np.mean(val_losses)
    # Create a new Summary object with mean_val_accuracy and
    # mean_val_loss, for the caller to write to TensorBoard.
    val_summary = tf.Summary(value=[
        tf.Summary.Value(tag="val_summaries/loss",
                         simple_value=mean_val_loss),
        tf.Summary.Value(tag="val_summaries/accuracy",
                         simple_value=mean_val_accuracy)])
    return mean_val_accuracy, mean_val_loss, val_summary
Source file: base_tf_model.py
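
The TODO/FIXME in the snippet points out that np.mean over per-batch values is only exact when every batch holds the same number of examples; the final batch of an epoch is often smaller. Below is a minimal sketch of a size-weighted mean that could replace the plain mean, assuming the per-batch example count is available. The weighted_mean helper and the len(val_batch[0]) access are illustrative assumptions, not part of the original class.

import numpy as np

def weighted_mean(per_batch_values, batch_sizes):
    # Weight each per-batch metric by the number of examples in that
    # batch so a smaller final batch does not skew the average.
    values = np.asarray(per_batch_values, dtype=np.float64)
    sizes = np.asarray(batch_sizes, dtype=np.float64)
    return float(np.sum(values * sizes) / np.sum(sizes))

# Inside the validation loop one would also record the batch size,
# e.g. batch_sizes.append(len(val_batch[0])), and then compute:
# mean_val_loss = weighted_mean(val_losses, batch_sizes)
# mean_val_accuracy = weighted_mean(val_accuracies, batch_sizes)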
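
The method builds the tf.Summary protobuf but leaves writing it to the caller. The following self-contained sketch shows one way a TF 1.x training loop might log such a summary with tf.summary.FileWriter; the constant loss value, the "logs/val" path, and the global_step of 100 are placeholders standing in for values that would normally come from _evaluate_on_validation and the training loop.

import tensorflow as tf

# In a real training loop these values would come from
# _evaluate_on_validation, e.g.:
# _, mean_val_loss, val_summary = model._evaluate_on_validation(
#     get_val_instance_generator, batch_size, num_val_steps, sess)
mean_val_loss = 0.42
val_summary = tf.Summary(value=[
    tf.Summary.Value(tag="val_summaries/loss",
                     simple_value=mean_val_loss)])

# Write the summary so it shows up in TensorBoard.
val_writer = tf.summary.FileWriter("logs/val")
val_writer.add_summary(val_summary, global_step=100)
val_writer.flush()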