losses.py source code

python

Project: deepmodels  Author: learningsociety
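import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.losses)


def clf_loss_oneclass(pred_logits, gt_labels, cls_num):
  """Compute classification loss for a single-label ("oneclass") problem.

  Each example belongs to exactly one of cls_num classes.

  Args:
    pred_logits: logits predicted by a model, shape [batch_size, cls_num].
    gt_labels: ground-truth integer class labels in [0, cls_num),
      shape [batch_size].
    cls_num: number of classes.
  Returns:
    scalar loss: mean softmax cross-entropy over the batch.
  """
  with tf.variable_scope("clf_loss"):
    gt_labels = tf.convert_to_tensor(gt_labels)
    # Labels must lie in [0, cls_num), so the largest label must be < cls_num.
    # The control dependency forces the check to actually run in graph mode.
    label_check = tf.assert_less(
        tf.reduce_max(gt_labels), tf.cast(cls_num, gt_labels.dtype))
    with tf.control_dependencies([label_check]):
      onehot_labels = tf.one_hot(gt_labels, cls_num)
    # Keep per-example losses (no reduction), then average over the batch.
    clf_loss_elem = tf.losses.softmax_cross_entropy(
        onehot_labels, pred_logits, reduction=tf.losses.Reduction.NONE)
    mean_loss = tf.reduce_mean(clf_loss_elem, 0)
  return mean_loss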
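To make explicit what this loss computes, here is a minimal pure-Python reference sketch, assuming integer labels in [0, cls_num) and one row of cls_num logits per example. For each example with logits z and label y, softmax cross-entropy against a one-hot target reduces to log(sum_j exp(z_j)) - z_y, and the batch loss is the mean of those values. The helper name softmax_cross_entropy_mean is hypothetical and is not part of the deepmodels project.

```python
import math
from typing import List, Sequence


# Hypothetical reference helper (not from deepmodels): mean softmax
# cross-entropy over a batch, written without TensorFlow.
def softmax_cross_entropy_mean(pred_logits: Sequence[Sequence[float]],
                               gt_labels: Sequence[int],
                               cls_num: int) -> float:
  if len(pred_logits) != len(gt_labels) or not gt_labels:
    raise ValueError("need a non-empty batch with one label per logits row")
  losses: List[float] = []
  for logits, label in zip(pred_logits, gt_labels):
    # Valid labels are 0 .. cls_num - 1; label == cls_num is out of range.
    if not 0 <= label < cls_num:
      raise ValueError(f"label {label} outside [0, {cls_num})")
    if len(logits) != cls_num:
      raise ValueError("each logits row must have cls_num entries")
    # Numerically stable log-sum-exp: subtract the max logit first.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # With a one-hot target: -log softmax(logits)[label].
    losses.append(log_sum_exp - logits[label])
  return sum(losses) / len(losses)


logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
labels = [0, 2]
print(softmax_cross_entropy_mean(logits, labels, cls_num=3))
```

Feeding the same logits and labels to clf_loss_oneclass in a TF 1.x session should give approximately the same value, which is a quick way to check the label range and reduction behaviour.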