lossfun.py source code


Project: alpha-dimt-icmlws    Author: sotetsuk
import torch
from torch.autograd import Variable  # legacy PyTorch (< 0.4) API, used throughout this file
from onmt import Constants  # provides the PAD index (assumed import path for the project's onmt package)


def memoryEfficientLoss(outputs, targets, generator, crit, max_generator_batches, eval=False):
    """Memory efficient loss.

    :param outputs: decoder output, seq_len x batch_size x logits_size
    :param targets: gold token indices, seq_len x batch_size
    :param generator: module projecting decoder states to vocabulary scores
    :param crit: loss criterion (e.g. NLLLoss summed over non-PAD tokens)
    :param max_generator_batches: number of timesteps per chunk fed to the generator
    :param eval: if True, skip gradient computation
    :return: (total loss, gradient w.r.t. outputs or None, number of correct non-PAD predictions)
    """
    # compute generations one piece at a time
    num_correct, loss = 0, 0
    # Detach from the graph; gradients accumulate on this Variable instead.
    outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)  # seq_len x batch_size x logits_size

    batch_size = outputs.size(1)
    # Split along dim 0 (time), so the large vocabulary projection is
    # materialized for at most max_generator_batches timesteps at once.
    outputs_split = torch.split(outputs, max_generator_batches)
    targets_split = torch.split(targets, max_generator_batches)
    for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)):
        # out_t = chunk_len x batch_size x logits_size
        # targ_t = chunk_len x batch_size

        out_t = out_t.view(-1, out_t.size(2))  # (chunk_len * batch_size) x logits_size
        scores_t = generator(out_t)  # (chunk_len * batch_size) x voc_size

        loss_t = crit(scores_t, targ_t.view(-1))  # scalar

        pred_t = scores_t.max(1)[1]  # (chunk_len * batch_size) x 1

        # Count correct predictions, ignoring PAD positions.
        num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(Constants.PAD).data).sum()
        num_correct += num_correct_t
        loss += loss_t.data[0]  # .data[0] extracts the Python scalar in legacy PyTorch
        if not eval:
            # Scale by batch size, then accumulate gradients onto `outputs`.
            loss_t.div(batch_size).backward()

    # Return the accumulated gradient so the caller can resume
    # backpropagation into the model via outputs.backward(grad_output).
    grad_output = None if outputs.grad is None else outputs.grad.data
    return loss, grad_output, num_correct
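
The memory saving comes from detaching `outputs` from the graph and feeding it to the generator one chunk of timesteps at a time, so the full seq_len x batch_size x voc_size score tensor is never materialized at once; the per-chunk backward calls accumulate gradients on the detached Variable, and the returned `grad_output` lets the caller resume backpropagation through the rest of the model. Below is a minimal usage sketch, assuming the legacy PyTorch (< 0.4) API used above and this file's imports; the shapes and the `generator`/`crit` definitions are illustrative, not taken from the project:

import torch.nn as nn

# Hypothetical sizes for the sketch.
seq_len, batch_size, hidden_size, voc_size = 5, 4, 8, 11

# Stand-ins for the decoder output and gold targets.
decoder_out = Variable(torch.randn(seq_len, batch_size, hidden_size), requires_grad=True)
targets = Variable(torch.LongTensor(seq_len, batch_size).random_(0, voc_size))

# Projects hidden states to log-probabilities over the vocabulary.
generator = nn.Sequential(nn.Linear(hidden_size, voc_size), nn.LogSoftmax())

# Sum NLL over tokens, giving PAD positions zero weight.
weight = torch.ones(voc_size)
weight[Constants.PAD] = 0
crit = nn.NLLLoss(weight, size_average=False)

loss, grad_output, num_correct = memoryEfficientLoss(
    decoder_out, targets, generator, crit, max_generator_batches=2)

# Training callers resume backpropagation into the model:
decoder_out.backward(grad_output)
print(loss, num_correct)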