def alpha_loss(outputs, targets, generator, crit, max_generator_batches, rewards, proposed_weights, tau, alpha, eval=False):
    """Compute the loss of the proposed (alpha-weighted) method, chunk by chunk.

    :param outputs: seq_len x batch_size x logits_size decoder outputs
    :param targets: seq_len x batch_size gold token indices
    :param generator: module mapping decoder outputs to vocabulary scores
    :param crit: criterion; called as crit(scores, targets, proposed_weights,
                 log_q_weights, alpha, rewards) and returning a 5-tuple
                 (loss, importance, p/q/pq sample-efficiency lists)
    :param max_generator_batches: chunk size (rows along dim 0) per generator pass,
                 to bound peak memory of the generator forward/backward
    :param rewards: per-sample rewards; also divided by tau to form log q weights
    :param proposed_weights: proposal-distribution weights, convertible to FloatTensor
    :param tau: temperature dividing rewards to build log q weights
    :param alpha: alpha hyperparameter forwarded to crit
    :param eval: if True, skip gradient computation entirely
                 (NOTE: shadows the ``eval`` builtin; kept for API compatibility)
    :return: (loss, grad_output, num_correct, importance_list,
              p_sample_efficiency_list, q_sample_efficiency_list,
              pq_sample_efficiency_list) where grad_output is the gradient
              w.r.t. ``outputs`` (or None) for the caller to backpropagate
    """
    num_correct, loss = 0, 0
    # Detach from the upstream graph; gradients accumulate into outputs.grad
    # and are handed back to the caller via grad_output.
    outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)  # seq_len x batch_size x logits_size
    batch_size = outputs.size(1)

    # FIX: these are loop-invariant. The original rebuilt them on every chunk
    # and rebound the `proposed_weights` argument each iteration (wrapping a
    # FloatTensor in a FloatTensor from the second chunk on). Hoisted out.
    proposed_weights = torch.FloatTensor(proposed_weights)
    log_q_weights = torch.FloatTensor(rewards) / tau

    # compute generations one piece at a time to bound memory usage
    outputs_split = torch.split(outputs, max_generator_batches)
    targets_split = torch.split(targets, max_generator_batches)

    # TODO(sotetsuk): fix to calculate at once
    importance_list = []
    p_sample_efficiency_list = []
    q_sample_efficiency_list = []
    pq_sample_efficiency_list = []
    for out_t, targ_t in zip(outputs_split, targets_split):
        out_t = out_t.view(-1, out_t.size(2))  # (chunk * batch_size) x logits_size
        scores_t = generator(out_t)  # (chunk * batch_size) x voc_size
        loss_t, importance_t, p_sample_efficiency_t, q_sample_efficiency_t, pq_sample_efficiency_t = crit(scores_t, targ_t.view(-1), proposed_weights, log_q_weights, alpha, rewards)  # scalar (1-d)
        pred_t = scores_t.max(1)[1]  # argmax over the vocabulary dimension
        # Count correct predictions on non-PAD positions only.
        num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(Constants.PAD).data).sum()
        num_correct += num_correct_t
        loss += loss_t.data[0]
        importance_list += importance_t
        p_sample_efficiency_list += p_sample_efficiency_t
        q_sample_efficiency_list += q_sample_efficiency_t
        pq_sample_efficiency_list += pq_sample_efficiency_t
        if not eval:
            # Normalize by batch size before accumulating into outputs.grad.
            loss_t.div(batch_size).backward()

    grad_output = None if outputs.grad is None else outputs.grad.data
    return loss, grad_output, num_correct, importance_list, p_sample_efficiency_list, q_sample_efficiency_list, pq_sample_efficiency_list