def _loss_softmax(self, logits, labels, is_training, weighted=False):
    """Softmax cross-entropy loss, optionally class-weighted, with L2
    regularization added during training."""
    log.info('Using softmax loss')
    labels = tf.cast(labels, tf.int64)
    # tf.rank() returns a tensor, so `tf.rank(labels) != 2` cannot be used
    # as a Python condition; check the static shape instead.
    if labels.get_shape().ndims != 2:
        labels = tf.one_hot(labels, self.num_classes)
    # The dense cross-entropy ops expect float labels matching the logits dtype.
    labels = tf.cast(labels, logits.dtype)
    if weighted:
        # Per-example weight: pick the weight of each example's true class.
        weights = self._compute_weights(labels)
        weights = tf.reduce_max(tf.multiply(weights, labels), axis=1)
        # loss_collection=None stops the op from also adding itself to the
        # 'losses' collection, which would double-count it in total_loss below.
        ce_loss = tf.losses.softmax_cross_entropy(
            labels,
            logits=logits,
            weights=weights,
            label_smoothing=self.label_smoothing,
            scope='cross_entropy_loss',
            loss_collection=None)
    else:
        ce_loss = tf.nn.softmax_cross_entropy_with_logits(
            labels=labels, logits=logits, name='cross_entropy_loss')
    ce_loss_mean = tf.reduce_mean(ce_loss, name='cross_entropy')
    if is_training:
        tf.add_to_collection('losses', ce_loss_mean)
        # Sum the regularization losses and scale by the configured L2 factor.
        l2_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        l2_loss = l2_loss * self.cnf.get('l2_reg', 0.0)
        tf.add_to_collection('losses', l2_loss)
        return tf.add_n(tf.get_collection('losses'), name='total_loss')
    return ce_loss_mean
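The `_compute_weights` helper is defined elsewhere in the class and is not shown here. For reference, a minimal sketch of what such a helper might look like, assuming it returns per-class weights inversely proportional to each class's frequency in the current batch; the helper's actual implementation is not part of this snippet, so everything below the signature is an assumption:

def _compute_weights(self, labels):
    # Hypothetical sketch, not the confirmed implementation: weight each
    # class inversely to its frequency in the current batch.
    # labels: one-hot float tensor of shape [batch_size, num_classes]
    class_counts = tf.reduce_sum(labels, axis=0)   # per-class example count
    total = tf.reduce_sum(class_counts)            # batch size
    # The +1 avoids division by zero for classes absent from the batch.
    return total / (float(self.num_classes) * (class_counts + 1.0))

With a `[num_classes]`-shaped weight vector like this, the multiply in `_loss_softmax` broadcasts against the one-hot labels, and `tf.reduce_max(..., axis=1)` then selects each example's true-class weight.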