EncoderDecoder.py source file

python

Project: bandit-nmt  Author: khanhptnk
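# Imports this method relies on; `lib` is the bandit-nmt project's own package.
import torch
from torch.autograd import Variable

import lib


# Method of the model class defined in EncoderDecoder.py.
def translate(self, inputs, max_length):
    """Greedily decode a batch; returns a (steps, batch_size) tensor of word ids."""
    targets, init_states = self.initialize(inputs, eval=True)
    emb, output, hidden, context = init_states

    preds = []
    batch_size = targets.size(1)
    # Byte mask with one flag per sentence, set once that sentence has emitted EOS.
    num_eos = targets[0].data.byte().new(batch_size).zero_()

    for i in range(max_length):
        output, hidden = self.decoder.step(emb, output, hidden, context)
        logit = self.generator(output)
        # Greedy choice: index of the highest-scoring word for each sentence.
        pred = logit.max(1)[1].view(-1).data
        preds.append(pred)

        # Stop if all sentences reach EOS.
        num_eos |= (pred == lib.Constants.EOS)
        if num_eos.sum() == batch_size:
            break

        # Feed the predicted words back in as the next decoder input.
        emb = self.decoder.word_lut(Variable(pred))

    preds = torch.stack(preds)
    return preds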
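
The loop above is a batched greedy decoder with an early stop. It keeps stepping every sentence until all of them have emitted EOS, so the stacked result can still contain tokens after a sentence's own EOS. A minimal pure-Python sketch of the same idea, plus one possible way to trim the output, might look like the following. The names `greedy_decode`, `trim_at_eos`, `toy_step`, and the values of `EOS` and `VOCAB_SIZE` are hypothetical and not part of bandit-nmt; `toy_step` only stands in for the decoder and generator.

```python
EOS = 3         # hypothetical id; the project reads the real one from lib.Constants.EOS
VOCAB_SIZE = 5  # hypothetical toy vocabulary size


def toy_step(tokens):
    """Stand-in for decoder + generator: puts all score on the word (tok + 1) % VOCAB_SIZE."""
    return [
        [1.0 if i == (tok + 1) % VOCAB_SIZE else 0.0 for i in range(VOCAB_SIZE)]
        for tok in tokens
    ]


def greedy_decode(step_fn, start_tokens, max_length, eos=EOS):
    """Return a list of steps, each holding one predicted word id per sentence."""
    tokens = list(start_tokens)
    finished = [False] * len(tokens)  # plays the role of the num_eos mask
    preds = []
    for _ in range(max_length):
        scores = step_fn(tokens)
        # Greedy choice: argmax over the vocabulary for every sentence.
        tokens = [max(range(len(row)), key=row.__getitem__) for row in scores]
        preds.append(tokens)
        # A sentence stays finished once it has produced EOS.
        finished = [done or tok == eos for done, tok in zip(finished, tokens)]
        if all(finished):
            break
    return preds


def trim_at_eos(preds, eos=EOS):
    """Turn (steps, batch) predictions into per-sentence lists cut at the first EOS."""
    sentences = []
    for column in zip(*preds):  # one column per sentence
        sentence = []
        for tok in column:
            if tok == eos:
                break
            sentence.append(tok)
        sentences.append(sentence)
    return sentences


if __name__ == "__main__":
    steps = greedy_decode(toy_step, [0, 1], max_length=10)
    print(steps)
    print(trim_at_eos(steps))
```

In this toy run the second sentence reaches EOS one step before the first, so it receives one extra token that `trim_at_eos` discards. Whether trimming belongs in the decoder or in the caller is a design choice; the `translate` method above leaves it to the caller.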