import tensorflow as tf


def loss(logits, label_batch):
    """
    Add L2 loss to all the trainable variables.
    Add summary for "Loss" and "Loss/avg".

    Args:
        logits -> logits from inference()
        label_batch -> 1-D tensor of shape [batch_size]

    Returns:
        total_loss -> float tensor
    """
    # Calculate the average cross-entropy loss across the batch.
    label_batch = tf.cast(label_batch, tf.int64)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=label_batch, logits=logits, name='cross_entropy_per_example')
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
    tf.add_to_collection('losses', cross_entropy_mean)

    # The total loss is the cross-entropy loss plus all of the weight
    # decay terms (L2 loss) previously added to the 'losses' collection.
    return tf.add_n(tf.get_collection('losses'), name='total_loss')
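Note that loss() only sums whatever is already in the 'losses' collection; the L2 weight decay terms themselves must be put into that collection earlier, when the trainable variables are created. Below is a minimal sketch of how that is commonly done in TF1-style code; the helper name _variable_with_weight_decay and its wd (weight decay coefficient) argument are illustrative assumptions, not part of the code above.

def _variable_with_weight_decay(name, shape, stddev, wd):
    # Illustrative helper (assumed, not from the code above): create a
    # trainable variable with truncated-normal initialization.
    var = tf.get_variable(
        name, shape,
        initializer=tf.truncated_normal_initializer(stddev=stddev))
    if wd is not None:
        # L2 penalty scaled by wd, collected alongside cross_entropy_mean
        # so that tf.add_n(tf.get_collection('losses')) sums all of them.
        weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
        tf.add_to_collection('losses', weight_decay)
    return var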