def clf_loss_multiclass(pred_logits, gt_labels, cls_num):
    """Compute the sigmoid cross-entropy classification loss.

    Args:
        pred_logits: logits prediction from a model.
        gt_labels: ground truth class labels [batch_size, num_cls] with
            (0,1) value.
        cls_num: number of classes. Not used by this function; kept so
            existing call sites keep working.

    Returns:
        The loss returned by ``tf.losses.sigmoid_cross_entropy`` for the
        given labels and logits, averaged over all of its elements.

    Raises:
        tf.errors.InvalidArgumentError: when the returned tensor is
            evaluated and the largest value in ``gt_labels`` is not 1.
    """
    with tf.variable_scope("clf_loss"):
        # Creating an assert op does not execute it in graph mode: unless
        # something depends on it, the check is silently skipped. Keep the
        # op and make the loss computation depend on it.
        labels_check = tf.assert_equal(tf.reduce_max(gt_labels), 1)
        with tf.control_dependencies([labels_check]):
            clf_loss_elem = tf.losses.sigmoid_cross_entropy(
                gt_labels, pred_logits)
            # With its default reduction the loss above is already a
            # scalar, for which an explicit axis-0 reduction is not valid.
            # Reducing over all axes yields the same scalar and also works
            # if the loss ever has a higher rank.
            mean_loss = tf.reduce_mean(clf_loss_elem)
    return mean_loss
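# Web-page residue left over from extraction (not part of the code):
# "评论列表" (comment list), "文章目录" (table of contents).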