lm.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:seqmod 作者: emanjavacas 项目源码 文件源码
def loss(self, batch_data, test=False):
        # unpack data
        (source, targets), conds = batch_data, None
        if self.conds is not None:
            (source, *conds), (targets, *_) = source, targets

        # get hidden from previous batch (if stored)
        hidden = self.hidden_state.get('hidden', None)
        # run RNN
        output, hidden, _ = self(source, hidden=hidden, conds=conds)
        # store hidden for next batch
        self.hidden_state['hidden'] = u.repackage_hidden(hidden)

        # compute loss and backward
        loss = F.nll_loss(output, targets.view(-1), size_average=True)

        if not test:
            loss.backward()

        return (loss.data[0], ), source.nelement()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号