import tensorflow as tf

import input  # the tutorial's input module, which defines FLAGS and NUM_CLASSES


def loss(logits, labels):
  """Builds the total loss: cross entropy plus all weight decay terms."""
  # Convert the sparse integer labels into dense one-hot labels of shape
  # [batch_size, NUM_CLASSES].
  sparse_labels = tf.reshape(labels, [input.FLAGS.batch_size, 1])
  indices = tf.reshape(tf.range(0, input.FLAGS.batch_size),
                       [input.FLAGS.batch_size, 1])
  concated = tf.concat(1, [indices, sparse_labels])
  dense_labels = tf.sparse_to_dense(concated,
                                    [input.FLAGS.batch_size, input.NUM_CLASSES],
                                    1.0, 0.0)
  # Calculate the average cross entropy loss across the batch.
  cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
      logits, dense_labels, name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)
  # The total loss is the cross entropy loss plus all of the weight decay
  # terms (L2 loss) previously added to the 'losses' collection.
  return tf.add_n(tf.get_collection('losses'), name='total_loss')
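As a rough usage sketch, the returned total loss would typically be fed into an optimizer during training. This sketch assumes the TensorFlow 0.x-style API that the snippet above targets (positional arguments to tf.concat and softmax_cross_entropy_with_logits), and hypothetical companion functions distorted_inputs() and inference() that supply the label and logit batches; they are not shown in this snippet.

# Hypothetical wiring; distorted_inputs() and inference() are assumed to be
# defined elsewhere in the tutorial code.
images, labels = distorted_inputs()   # images plus [batch_size] integer labels
logits = inference(images)            # [batch_size, NUM_CLASSES] unscaled scores
total_loss = loss(logits, labels)     # cross entropy + weight decay terms
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(total_loss)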