Loss.py source code

Project: OpenNMT-py  Author: OpenNMT
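import torch


def compute_loss(self, batch, output, target):
    """ See base class for args description. """
    # Project decoder outputs (flattened to 2-D by self.bottle) to
    # log-probabilities over the target vocabulary.
    scores = self.generator(self.bottle(output))

    gtruth = target.view(-1)
    if self.confidence < 1:
        # Label smoothing: replace each gold index with a smoothed distribution.
        tdata = gtruth.detach()
        # 1-D tensor of padding positions. view(-1) keeps it 1-D even when
        # there is exactly one match (squeeze() would make it 0-dim, and the
        # old `mask.dim() > 0` check would then skip masking that position).
        mask = torch.nonzero(tdata.eq(self.padding_idx)).view(-1)
        # Log-probability of each gold token, kept for reporting the NLL.
        likelihood = torch.gather(scores.detach(), 1, tdata.unsqueeze(1))
        # Start from the smoothed base row, then put self.confidence on the gold index.
        tmp_ = self.one_hot.repeat(gtruth.size(0), 1)
        tmp_.scatter_(1, tdata.unsqueeze(1), self.confidence)
        if mask.numel() > 0:
            # Padding rows contribute nothing to the loss or to the report.
            likelihood.index_fill_(0, mask, 0)
            tmp_.index_fill_(0, mask, 0)
        # Smoothed targets need no gradient (replaces the deprecated
        # Variable(tmp_, requires_grad=False) wrapper).
        gtruth = tmp_

    loss = self.criterion(scores, gtruth)
    if self.confidence < 1:
        # Report the plain NLL of the gold tokens, not the smoothed loss.
        loss_data = -likelihood.sum(0)
    else:
        loss_data = loss.detach().clone()

    stats = self.stats(loss_data, scores.detach(), target.view(-1).detach())

    return loss, stats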
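The label-smoothing branch of compute_loss can be illustrated without PyTorch. The following is a minimal pure-Python sketch under stated assumptions, not OpenNMT-py's own code. It assumes that `self.one_hot` holds `smoothing / (V - 2)` for every class except padding (which holds 0), and that `self.criterion` is a KL-divergence loss summed over all elements. The helper names `smoothed_targets`, `kl_div_sum` and `gold_nll` are hypothetical.

```python
import math
from typing import List, Sequence


def smoothed_targets(targets: Sequence[int], vocab_size: int,
                     padding_idx: int, smoothing: float) -> List[List[float]]:
    """One smoothed target row per token (mirrors tmp_ in compute_loss)."""
    confidence = 1.0 - smoothing
    # Assumed base row (like self.one_hot): smoothing mass spread over the
    # vocab_size - 2 classes that are neither gold nor padding; padding gets 0.
    base = [smoothing / (vocab_size - 2)] * vocab_size
    base[padding_idx] = 0.0
    rows = []
    for t in targets:
        row = list(base)                  # like one_hot.repeat(N, 1)
        row[t] = confidence               # like scatter_(1, t, confidence)
        if t == padding_idx:              # like index_fill_(0, mask, 0)
            row = [0.0] * vocab_size
        rows.append(row)
    return rows


def kl_div_sum(log_probs: Sequence[Sequence[float]],
               target_rows: Sequence[Sequence[float]]) -> float:
    """Summed KL(q || p) given log p; entries with q == 0 contribute 0."""
    total = 0.0
    for lp_row, q_row in zip(log_probs, target_rows):
        for lp, q in zip(lp_row, q_row):
            if q > 0:
                total += q * (math.log(q) - lp)
    return total


def gold_nll(log_probs: Sequence[Sequence[float]], targets: Sequence[int],
             padding_idx: int) -> float:
    """Reported loss (like -likelihood.sum(0)): gold-token NLL, padding excluded."""
    return -sum(row[t] for row, t in zip(log_probs, targets) if t != padding_idx)


if __name__ == "__main__":
    targets = [2, 3, 0]                   # hypothetical batch; index 0 is padding
    rows = smoothed_targets(targets, vocab_size=4, padding_idx=0, smoothing=0.1)
    uniform = [[math.log(0.25)] * 4 for _ in targets]
    print(rows)
    print(kl_div_sum(uniform, rows), gold_nll(uniform, targets, 0))
```

Each non-padding row sums to 1: the gold class gets `1 - smoothing`, and the remaining `V - 2` non-padding classes share `smoothing`. The model is trained on the smoothed KL term, while the statistics report only the gold-token NLL. This split matches the two `if self.confidence < 1` branches above.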