import tensorflow as tf


def cross_entropy_sequence_loss(logits, targets, sequence_length):
    """Per-token cross-entropy loss over time-major, padded sequences.

    Args:
        logits: float Tensor of shape [T, B, vocab_size].
        targets: int Tensor of shape [T, B].
        sequence_length: int Tensor of shape [B], the true length of each sequence.
    """
    with tf.name_scope('cross_entropy_sequence_loss'):
        # Total number of real (non-padding) tokens across the batch.
        total_length = tf.to_float(tf.reduce_sum(sequence_length))
        entropy_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=targets)
        # Mask out the losses we don't care about (padding positions).
        # tf.sequence_mask yields shape [B, T], so transpose to the
        # time-major layout [T, B] to match entropy_losses.
        loss_mask = tf.sequence_mask(
            tf.to_int32(sequence_length), tf.to_int32(tf.shape(targets)[0]))
        loss_mask = tf.transpose(tf.to_float(loss_mask), [1, 0])
        losses = entropy_losses * loss_mask  # shape: [T, B]
        # Average over real tokens only, not over padded positions.
        total_loss_avg = tf.reduce_sum(losses) / total_length
        return total_loss_avg
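
As a quick sanity check, the loss can be exercised with small dummy tensors. This is a minimal sketch assuming TensorFlow 1.x; the shapes (T=3, B=2, vocab=5), the concrete values, and the tf.Session usage are illustrative assumptions, not part of the original post.

import numpy as np

# Time-major dummy inputs: [T, B, V] logits, [T, B] targets, [B] lengths.
logits = tf.constant(np.random.randn(3, 2, 5), dtype=tf.float32)
targets = tf.constant([[1, 2], [3, 0], [4, 0]], dtype=tf.int32)
lengths = tf.constant([3, 1], dtype=tf.int32)  # second sequence: only step 0 is real

loss = cross_entropy_sequence_loss(logits, targets, lengths)
with tf.Session() as sess:
    # Average over the 4 real tokens; the 2 padded positions are masked out.
    print(sess.run(loss))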