EncoderDecoder.py source code

python

Project: bandit-nmt  Author: khanhptnk
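# Names this method relies on (it is a method of the model class, hence `self`).
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import lib

def sample(self, inputs, max_length):
    # Sample one output sequence per input sentence, drawing each token
    # from the distribution the model predicts at that step.
    targets, init_states = self.initialize(inputs, eval=False)
    emb, output, hidden, context = init_states

    outputs = []
    samples = []
    batch_size = targets.size(1)
    # One byte flag per sentence, set once that sentence has sampled EOS.
    num_eos = targets[0].data.byte().new(batch_size).zero_()

    for i in range(max_length):
        output, hidden = self.decoder.step(emb, output, hidden, context)
        outputs.append(output)
        dist = F.softmax(self.generator(output))
        # Draw one token per sentence from the predicted distribution.
        sample = dist.multinomial(1, replacement=False).view(-1).data
        samples.append(sample)

        # Stop if all sentences reach EOS.
        num_eos |= (sample == lib.Constants.EOS)
        if num_eos.sum() == batch_size:
            break

        # Feed the sampled tokens back in as the next decoder input.
        emb = self.decoder.word_lut(Variable(sample))

    # Stack the per-step results along a new leading (time) dimension.
    outputs = torch.stack(outputs)
    samples = torch.stack(samples)
    return samples, outputs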
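
The control flow of `sample` can be tried out without a trained model. The block below is a minimal standard-library sketch, not the project's code: `sample_batch`, `step_fn`, and the token id `EOS = 3` are hypothetical names chosen for illustration, and `step_fn` stands in for the decoder step plus softmax by returning one probability list per sentence.

```python
import random

EOS = 3  # hypothetical end-of-sentence id; the original uses lib.Constants.EOS


def sample_batch(step_fn, batch_size, max_length, rng):
    """Sample one token per sentence per step; stop once every sentence has hit EOS."""
    samples = []
    finished = [False] * batch_size   # plays the role of num_eos
    prev = [None] * batch_size        # previously sampled tokens (None at step 0)
    for _ in range(max_length):
        dists = step_fn(prev)         # one probability list per sentence
        tokens = [rng.choices(range(len(p)), weights=p, k=1)[0] for p in dists]
        samples.append(tokens)
        # A flag stays set once its sentence has produced EOS (like `|=`).
        finished = [f or t == EOS for f, t in zip(finished, tokens)]
        if all(finished):
            break
        prev = tokens                 # feed samples back in as the next input
    return samples


def uniform_step(prev):
    # Toy "model": every token id 0..3 is equally likely for every sentence.
    return [[0.25, 0.25, 0.25, 0.25] for _ in prev]


steps = sample_batch(uniform_step, batch_size=2, max_length=10, rng=random.Random(0))
print(len(steps), "steps sampled")
```

As in the original method, a sentence that has already produced EOS keeps being sampled until the whole batch is done, so callers need to truncate each sequence at its first EOS.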