def compute_loss(logits, y, lens):
    """Compute the length-masked negative log-likelihood of ``y`` under ``logits``.

    Args:
        logits: 3-D tensor unpacked as ``(batch_size, seq_len, vocab_size)``
            holding unnormalized scores.
        y: Tensor of target indices with ``batch_size * seq_len`` elements;
            it is used as the index argument of ``torch.gather``, so it must
            be an integer (long) tensor with values in ``[0, vocab_size)``.
        lens: Tensor of sequence lengths with one entry per batch element.
            It is forwarded to ``sequence_mask`` together with ``seq_len``.
            NOTE(review): ``sequence_mask`` is defined outside this block; it
            is assumed to return a ``(batch_size, seq_len)`` mask that is
            non-zero for valid positions -- confirm against its definition.

    Returns:
        A tuple ``(loss_batch, loss_step)`` where ``loss_batch`` is the summed
        masked loss divided by ``len(lens)`` and ``loss_step`` is the same sum
        divided by ``lens.sum()``.
    """
    batch_size, seq_len, vocab_size = logits.size()

    # reshape() instead of view(): also accepts non-contiguous inputs
    # (e.g. tensors produced by a transpose or a slice).
    flat_logits = logits.reshape(batch_size * seq_len, vocab_size)
    flat_targets = y.reshape(-1)

    # Pass the softmax dimension explicitly; relying on the implicit
    # dimension is deprecated and emits a warning.
    logprobs = F.log_softmax(flat_logits, dim=-1)

    # Select the log-probability of each target index and negate it.
    losses = -torch.gather(logprobs, 1, flat_targets.unsqueeze(-1))
    losses = losses.view(batch_size, seq_len)

    # Zero out the contributions of masked positions.
    mask = sequence_mask(lens, seq_len).float()
    losses = losses * mask

    # Reduce once and reuse the total for both normalizations.
    total_loss = losses.sum()
    loss_batch = total_loss / len(lens)
    loss_step = total_loss / lens.sum().float()
    return loss_batch, loss_step
评论列表
文章目录