def clf_loss_oneclass(pred_logits, gt_labels, cls_num):
    """Compute the mean softmax cross-entropy loss for single-label classification.

    Each example carries exactly one integer class label. The label is
    one-hot encoded to depth ``cls_num`` and scored against ``pred_logits``
    with softmax cross-entropy. The per-example losses are then averaged
    over axis 0.

    Args:
        pred_logits: unnormalized logits predicted by a model. Presumably
            shaped ``(batch, cls_num)`` -- TODO confirm against callers.
        gt_labels: integer ground-truth class indices, one per example
            (presumably shaped ``(batch,)`` -- verify). Every value must lie
            in ``[0, cls_num)``.
        cls_num: number of classes, used as the one-hot depth.

    Returns:
        The per-example cross-entropy losses averaged over axis 0.

    Raises:
        tf.errors.InvalidArgumentError: when the graph is run, if any label
            is negative or not less than ``cls_num``.
    """
    with tf.variable_scope("clf_loss"):
        gt_labels = tf.convert_to_tensor(gt_labels)
        # Cast so the bound check compares tensors of the same dtype
        # (e.g. int64 labels vs. a Python int that would become int32).
        num_classes = tf.cast(cls_num, gt_labels.dtype)
        # tf.one_hot maps an out-of-range index to an all-zero row, which
        # silently gives that example zero loss, so reject such labels.
        label_checks = [
            tf.assert_non_negative(
                gt_labels, message="gt_labels must be >= 0"),
            tf.assert_less(
                gt_labels, num_classes, message="gt_labels must be < cls_num"),
        ]
        # In graph mode an assert op only executes if something depends on it.
        with tf.control_dependencies(label_checks):
            onehot_labels = tf.one_hot(gt_labels, cls_num)
        # Reduction.NONE keeps one loss value per example so that the mean
        # below has an axis 0 to reduce; the default reduction already
        # returns a scalar, for which reduce_mean(..., 0) is invalid.
        clf_loss_elem = tf.losses.softmax_cross_entropy(
            onehot_labels, pred_logits,
            reduction=tf.losses.Reduction.NONE)
        mean_loss = tf.reduce_mean(clf_loss_elem, 0)
    return mean_loss
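# 评论列表 ("comment list") -- web-page residue, not code
# 文章目录 ("table of contents") -- web-page residue, not code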