def get_classification_loss(logits, targets, softmax_loss_function=None):
    """Return the batch-averaged classification loss for a single bucket.

    The per-example cross-entropy is summed and divided by the size of the
    first dimension of ``targets[0]``, which is treated as the batch size.

    Args:
      logits: Sequence holding exactly one logits tensor. Presumably shaped
        (batch, num_classes) -- NOTE(review): confirm against callers.
      targets: Sequence holding exactly one target tensor. On the default
        path it is cast to int64 and flattened to a vector of class ids.
      softmax_loss_function: Optional callable invoked as
        ``softmax_loss_function(logits[0], targets[0])``. It must return
        per-example losses. When None, sparse softmax cross-entropy is used.

    Returns:
      A scalar tensor equal to ``reduce_sum(crossent) / batch_size``.

    Raises:
      ValueError: If ``logits`` and ``targets`` do not each contain exactly
        one element.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not len(logits) == len(targets) == 1:
        raise ValueError(
            "Expected exactly one logits tensor and one targets tensor, "
            "got %d and %d." % (len(logits), len(targets)))

    bucket_logits = logits[0]
    bucket_targets = targets[0]

    if softmax_loss_function is None:
        # The sparse op needs integer class ids as a flat int64 vector.
        flat_labels = array_ops.reshape(
            math_ops.cast(bucket_targets, dtypes.int64), [-1])
        # Keyword arguments are mandatory: TF >= 1.0 raises on positional
        # calls to this op, and older releases ordered them (logits, labels).
        crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
            labels=flat_labels, logits=bucket_logits)
    else:
        crossent = softmax_loss_function(bucket_logits, bucket_targets)

    total = math_ops.reduce_sum(crossent)
    batch_size = array_ops.shape(bucket_targets)[0]
    # Cast to the loss dtype (not a hard-coded float32) so float16/float64
    # losses divide without a dtype mismatch.
    return total / math_ops.cast(batch_size, total.dtype)
评论列表
文章目录