def compute_loss(self, decoder_output, _features, labels):
    """Build the training loss from the decoder logits and target labels.

    Args:
      decoder_output: Decoder result; only its `logits` attribute is read.
      _features: Not used by this implementation.
      labels: Mapping that provides `"target_ids"` and `"target_len"`.

    Returns:
      A tuple `(losses, loss)`. `losses` is the tensor returned by
      `seq2seq_losses.cross_entropy_sequence_loss`; `loss` is a scalar
      tensor equal to the sum of `losses` divided by the sum of
      `target_len - 1` over the batch.
    """
    #pylint: disable=R0201
    # Every sequence is one step shorter once the first target column is
    # dropped below (presumably the sequence-start token -- confirm).
    shifted_len = labels["target_len"] - 1

    # Drop the first column of the target ids and swap the two axes so the
    # targets line up with the logits (assumes time-major logits -- TODO
    # confirm against the decoder).
    shifted_targets = tf.transpose(labels["target_ids"][:, 1:], [1, 0])

    # Cross entropy for every example/timestep pair.
    step_losses = seq2seq_losses.cross_entropy_sequence_loss(
        logits=decoder_output.logits[:, :, :],
        targets=shifted_targets,
        sequence_length=shifted_len)

    # Average log perplexity: total loss over the total token count.
    token_count = tf.to_float(tf.reduce_sum(shifted_len))
    mean_loss = tf.reduce_sum(step_losses) / token_count
    return step_losses, mean_loss
# (Web page residue, not part of the code: "Comment list" / "Table of contents")