model.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:pytorch-skipthoughts 作者: kaniblu 项目源码 文件源码
def compute_loss(logits, y, lens):
    batch_size, seq_len, vocab_size = logits.size()
    logits = logits.view(batch_size * seq_len, vocab_size)
    y = y.view(-1)

    logprobs = F.log_softmax(logits)
    losses = -torch.gather(logprobs, 1, y.unsqueeze(-1))
    losses = losses.view(batch_size, seq_len)
    mask = sequence_mask(lens, seq_len).float()
    losses = losses * mask
    loss_batch = losses.sum() / len(lens)
    loss_step = losses.sum() / lens.sum().float()

    return loss_batch, loss_step
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号