import tensorflow as tf


def cross_entropy_sequence_loss(logits, targets, sequence_length):
  """Calculates the per-example cross-entropy loss for a sequence of logits and
  masks out all losses past the sequence length.

  Args:
    logits: Logits of shape `[T, B, vocab_size]`
    targets: Target classes of shape `[T, B]`
    sequence_length: An int32 tensor of shape `[B]` corresponding
      to the length of each input

  Returns:
    A tensor of shape [T, B] that contains the loss per example, per time step.
  """
  with tf.name_scope("cross_entropy_sequence_loss"):
    # Per-token cross-entropy over the vocabulary dimension.
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=targets)

    # Mask out the losses we don't care about: sequence_mask yields a
    # [B, T] boolean mask, which is transposed to [T, B] to match `losses`.
    loss_mask = tf.sequence_mask(
        tf.to_int32(sequence_length), tf.to_int32(tf.shape(targets)[0]))
    losses = losses * tf.transpose(tf.to_float(loss_mask), [1, 0])

    return losses
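

# Example usage: a minimal sketch, assuming TensorFlow 1.x graph mode.
# The shapes and sequence lengths below are made up for illustration;
# they are not part of the original code.
if __name__ == "__main__":
  T, B, V = 5, 2, 10  # time steps, batch size, vocab size
  logits = tf.random_normal([T, B, V])
  targets = tf.zeros([T, B], dtype=tf.int32)
  sequence_length = tf.constant([5, 3], dtype=tf.int32)

  losses = cross_entropy_sequence_loss(logits, targets, sequence_length)
  with tf.Session() as sess:
    # Entries past each sequence's length are zeroed out by the mask,
    # e.g. the last two time steps of the second example here.
    print(sess.run(losses).shape)  # (5, 2)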