def do_eval(sess,
            eval_loss,
            images_placeholder,
            labels_placeholder,
            training_time_placeholder,
            images,
            labels,
            batch_size):
    '''
    Function for running the evaluations every X iterations on the training and validation sets.
    :param sess: The current tf session
    :param eval_loss: The evaluation op; running it returns the loss and the dice score for a batch
    :param images_placeholder: Placeholder for the images
    :param labels_placeholder: Placeholder for the masks
    :param training_time_placeholder: Boolean placeholder toggling between training and testing mode
    :param images: A numpy array or h5py dataset containing the images
    :param labels: A numpy array or h5py dataset containing the corresponding labels
    :param batch_size: The batch size to use
    :return: The average loss (as defined in the experiment), and the average dice over all `images`.
    '''
    loss_ii = 0
    dice_ii = 0
    num_batches = 0

    # As before, iterate_minibatches can be wrapped in the BackgroundGenerator class for
    # speed improvements, but at the risk of exceptions in the generator going uncaught
    # (a minimal sketch of such a wrapper follows this function).
    for batch in BackgroundGenerator(iterate_minibatches(images, labels, batch_size=batch_size, augment_batch=False)):  # no augmentation during evaluation

        x, y = batch

        # Skip incomplete final batches so every batch contributes equally to the averages
        if y.shape[0] < batch_size:
            continue

        feed_dict = {images_placeholder: x,
                     labels_placeholder: y,
                     training_time_placeholder: False}

        closs, cdice = sess.run(eval_loss, feed_dict=feed_dict)
        loss_ii += closs
        dice_ii += cdice
        num_batches += 1

    avg_loss = loss_ii / num_batches
    avg_dice = dice_ii / num_batches

    logging.info('  Average loss: %0.04f, average dice: %0.04f' % (avg_loss, avg_dice))

    return avg_loss, avg_dice
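
# The loop above mentions wrapping iterate_minibatches in a BackgroundGenerator for
# speed. For reference, below is a minimal sketch of such a prefetching wrapper,
# assuming a single worker thread feeding a bounded queue; the class actually used in
# this project may differ, and `max_prefetch` is an illustrative parameter name.
# Note that an exception raised inside run() kills the worker thread silently, which
# is exactly the risk the comment above warns about.

import queue
import threading


class BackgroundGenerator(threading.Thread):

    def __init__(self, generator, max_prefetch=3):
        super(BackgroundGenerator, self).__init__()
        self.queue = queue.Queue(max_prefetch)  # bounded, so prefetching cannot outrun the consumer
        self.generator = generator
        self.daemon = True  # don't block interpreter shutdown on this thread
        self.start()

    def run(self):
        # Produce batches on the background thread while the main thread runs sess.run
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)  # sentinel signalling exhaustion

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()
        if item is None:
            raise StopIteration
        return item


# Hypothetical call site, evaluating on a held-out validation set every X steps
# (all names here are illustrative, not taken from the original code):
#
#     val_loss, val_dice = do_eval(sess, eval_loss,
#                                  images_pl, labels_pl, training_pl,
#                                  val_images, val_labels, batch_size)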