import tensorflow as tf

# Collection key under which all loss terms (cross entropy plus the weight
# decay terms added during graph construction) are accumulated. The value
# 'losses' is an assumption; the original may define this constant elsewhere.
LOSSES = 'losses'


def loss(self, logits, labels):
    """Combine the cross entropy loss with the L2 weight decay terms.

    The weight decay (L2) losses are assumed to have already been added
    to the LOSSES collection when the trainable variables were created.

    Args:
        logits: Logits from get().
        labels: Labels from train_inputs or inputs(). 1-D tensor
            of shape [batch_size].

    Returns:
        Loss tensor of type float.
    """
    with tf.variable_scope('loss'):
        # Calculate the average cross entropy loss across the batch.
        labels = tf.cast(labels, tf.int64)
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels, name='cross_entropy_per_example')
        cross_entropy_mean = tf.reduce_mean(
            cross_entropy, name='cross_entropy')
        tf.add_to_collection(LOSSES, cross_entropy_mean)
        # The total loss is the cross entropy loss plus all of the
        # weight decay terms (L2 loss) in the collection.
        error = tf.add_n(tf.get_collection(LOSSES), name='total_loss')
    return error
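
# For tf.add_n above to pick up weight decay, the L2 penalties must be
# placed in the same LOSSES collection at variable-creation time. Below is
# a minimal sketch of that convention; the helper name
# _variable_with_weight_decay and the wd parameter are illustrative
# assumptions, not part of the original code.
def _variable_with_weight_decay(name, shape, stddev, wd):
    # Create a trainable variable and, if wd is set, register its scaled
    # L2 norm in the LOSSES collection that loss() later sums.
    var = tf.get_variable(
        name, shape,
        initializer=tf.truncated_normal_initializer(stddev=stddev))
    if wd is not None:
        weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
        tf.add_to_collection(LOSSES, weight_decay)
    return var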