import tensorflow as tf


def classification_costs(logits, labels, name=None):
    """Compute the mean classification cost and the per-sample classification cost.

    Unlabeled examples are assumed to carry label == -1 and contribute
    cost == 0. The mean is taken over all examples, labeled and unlabeled.
    Note that unlabeled examples are treated differently in error calculation.
    """
    with tf.name_scope(name, "classification_costs") as scope:
        applicable = tf.not_equal(labels, -1)

        # Replace the -1 placeholders with zeros so cross-entropy is computable.
        labels = tf.where(applicable, labels, tf.zeros_like(labels))

        # The per-sample costs are now incorrect for the unlabeled examples ...
        per_sample = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels)

        # ... so zero them out, retaining costs only for labeled examples.
        per_sample = tf.where(applicable, per_sample, tf.zeros_like(per_sample))

        # Take the mean over all examples, not just the labeled ones.
        labeled_sum = tf.reduce_sum(per_sample)
        total_count = tf.to_float(tf.shape(per_sample)[0])
        mean = tf.div(labeled_sum, total_count, name=scope)

        return mean, per_sample
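
To make the masking behavior concrete, here is a minimal usage sketch, assuming TensorFlow 1.x graph mode; the toy logits and labels below are invented for illustration. The middle example is unlabeled (label == -1), so its per-sample cost comes out as exactly 0.0, and the mean divides the summed labeled costs by all three examples, not just the two labeled ones.

# Minimal usage sketch (assumes TensorFlow 1.x; toy data for illustration only).
import tensorflow as tf

logits = tf.constant([[2.0, 0.5],
                      [0.1, 3.0],
                      [1.0, 1.0]])
labels = tf.constant([0, -1, 1])  # middle example is unlabeled

mean_cost, per_sample_cost = classification_costs(logits, labels)

with tf.Session() as sess:
    mean_val, per_sample_val = sess.run([mean_cost, per_sample_cost])
    print(per_sample_val)  # second entry is 0.0: the unlabeled example is masked out
    print(mean_val)        # sum of the two labeled costs divided by 3, not by 2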