def total_loss_sum(losses):
    '''
    Sum the given losses together with the graph's regularization losses.

    Parameters
    ----------
    losses : list or tuple
        Loss tensors to combine. Any iterable is accepted; it is copied
        into a new list before the regularization terms are appended, so
        the caller's object is never modified.

    Returns
    -------
    total_loss : Tensor
        Result of ``tf.add_n`` over the given losses plus every entry of
        the ``tf.GraphKeys.REGULARIZATION_LOSSES`` collection. The op is
        named 'total_loss'.

    Notes
    -----
    The collection contains whichever regularization terms were registered
    when the model was built; this function does not compute or restrict
    them to L2 itself.
    NOTE(review): ``tf.add_n`` is expected to reject an empty input list,
    so calling this with no losses and an empty collection will likely
    raise -- confirm against the TensorFlow version in use.
    '''
    # No scope is passed to get_collection, so every regularization term in
    # the collection is included rather than only those of a single tower.
    regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

    # Normalise both operands to lists: a bare ``losses + [...]`` raises
    # TypeError when the caller hands in a tuple or another iterable.
    all_losses = list(losses) + list(regularization_losses)
    return tf.add_n(all_losses, name='total_loss')
# Comment list
# Table of contents