def memoryEfficientLoss(outputs, targets, generator, crit, max_generator_batches, eval=False):
    """Compute the generation loss in shards to bound peak memory usage.

    The decoder outputs are detached from the main graph, split into shards
    of at most ``max_generator_batches`` slices along dim 0, and the
    generator + criterion are applied shard by shard.  Gradients w.r.t. the
    detached outputs are accumulated on the local ``Variable`` and returned,
    so the caller can continue back-propagation through the decoder/encoder.

    :param outputs: decoder states, seq_len x batch_size x logits_size
    :param targets: gold token indices, seq_len x batch_size
    :param generator: module mapping decoder states to vocabulary scores
    :param crit: criterion producing a summed scalar loss per shard
    :param max_generator_batches: maximum shard size along dim 0
    :param eval: if True, skip the backward pass (inference mode)
    :return: (accumulated loss value, gradient w.r.t. ``outputs`` data or
        None, number of correctly predicted non-PAD tokens)
    """
    num_correct, loss = 0, 0
    # Detach from the main graph; grads collect on this leaf Variable and are
    # handed back via `grad_output` for the caller to backprop further.
    outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)  # seq_len x batch_size x logits_size
    batch_size = outputs.size(1)
    # NOTE(review): torch.split with no dim argument splits along dim 0, i.e.
    # the time axis here — shards are groups of time steps, not batch pieces.
    outputs_split = torch.split(outputs, max_generator_batches)
    targets_split = torch.split(targets, max_generator_batches)
    for out_t, targ_t in zip(outputs_split, targets_split):
        # out_t:  shard_len x batch_size x logits_size
        # targ_t: shard_len x batch_size
        out_t = out_t.view(-1, out_t.size(2))     # (shard_len * batch_size) x logits_size
        scores_t = generator(out_t)               # (shard_len * batch_size) x voc_size
        loss_t = crit(scores_t, targ_t.view(-1))  # scalar (summed over the shard)
        pred_t = scores_t.max(1)[1]               # argmax token indices
        # Count correct predictions, ignoring PAD positions.
        num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(Constants.PAD).data).sum()
        num_correct += num_correct_t
        loss += loss_t.data[0]
        if not eval:
            # Normalize by batch size; grads accumulate into `outputs.grad`
            # across shards.
            loss_t.div(batch_size).backward()
    grad_output = None if outputs.grad is None else outputs.grad.data
    return loss, grad_output, num_correct
# NOTE(review): removed stray page-scrape residue that broke the file
# ("评论列表" = comment list, "文章目录" = article table of contents).