lstm_attention.py source code

python

Project: pytorch-seq2seq  Author: rowanz
def _sample(self, state, context, mask, max_len=20):
        """
        Performs sampling
        """
        batch_size = state.size(0)

        # Start every sequence with a row of BOS tokens
        toks = [const_row(self.bos_token, batch_size, volatile=True)]

        # Decoded length of each sequence; 0 means "not finished yet".
        # Must be zero-initialized, since the (lens == 0) checks below rely on it.
        lens = torch.IntTensor(batch_size).zero_()
        if torch.cuda.is_available():
            lens = lens.cuda()

        for l in range(max_len + 1):  # +1 because of EOS
            out, state, alpha = self._lstm_loop(state, self.embedding(toks[-1]), context, mask)

            # Do argmax (since we're doing greedy decoding)
            toks.append(out.max(1)[1].squeeze(1))

            # Record the length of sequences that just emitted EOS (and were not already finished)
            lens[(toks[-1].data == self.eos_token) & (lens == 0)] = l + 1
            if all(lens):
                break
        # Sequences that never emitted EOS are clipped to the maximum length
        lens[lens == 0] = max_len + 1
        return torch.stack(toks, 0), lens
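
The snippet depends on two project helpers that are not shown here: const_row, which builds the initial batch of BOS tokens, and self._lstm_loop, which runs one attention-augmented decoder step. Below is a minimal sketch of what const_row might look like and how _sample could be invoked; the helper's exact signature and the decoder variable names are assumptions for illustration, not the actual rowanz/pytorch-seq2seq code.

python

# Hypothetical sketch of the const_row helper assumed by _sample.
# Uses old-style PyTorch (<= 0.3) Variables, matching the volatile=True flag above.
import torch
from torch.autograd import Variable

def const_row(fill_value, batch_size, volatile=False):
    # One identical token id per batch element, e.g. a row of BOS tokens.
    row = Variable(torch.LongTensor(batch_size).fill_(fill_value), volatile=volatile)
    if torch.cuda.is_available():
        row = row.cuda()
    return row

# Hypothetical usage (decoder, state, context, mask are assumed to exist):
# toks, lens = decoder._sample(state, context, mask, max_len=20)
# toks stacks the token rows along dim 0 (the first row is BOS);
# lens holds each sequence's decoded length, counting the EOS token.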