test_topkdecoder.py source code

python

Project: pytorch-seq2seq · Author: IBM
import unittest

import numpy as np
import torch

from seq2seq.models import DecoderRNN, TopKDecoder


class TestTopKDecoder(unittest.TestCase):
    # The page shows only the test_k_1 method; the imports, class wrapper and
    # setUp below are reconstructed so the snippet runs as a standalone test.

    def setUp(self):
        self.vocab_size = 100  # assumed value; any positive vocabulary size works here

    def test_k_1(self):
        """ When k=1, the output of topk decoder should be the same as a normal decoder. """
        batch_size = 1
        eos = 1

        for _ in range(10):
            # Repeat the randomized test multiple times
            decoder = DecoderRNN(self.vocab_size, 50, 16, 0, eos)
            for param in decoder.parameters():
                param.data.uniform_(-1, 1)
            topk_decoder = TopKDecoder(decoder, 1)

            output, _, other = decoder()
            output_topk, _, other_topk = topk_decoder()

            self.assertEqual(len(output), len(output_topk))

            finished = [False] * batch_size
            seq_scores = [0] * batch_size

            for t_step, t_output in enumerate(output):
                score, _ = t_output.topk(1)
                symbols = other['sequence'][t_step]
                for b in range(batch_size):
                    seq_scores[b] += score[b].data[0]
                    symbol = symbols[b].data[0]
                    if not finished[b] and symbol == eos:
                        finished[b] = True
                        self.assertEqual(other_topk['length'][b], t_step + 1)
                        self.assertTrue(np.isclose(seq_scores[b], other_topk['score'][b][0]))
                    if not finished[b]:
                        symbol_topk = other_topk['topk_sequence'][t_step][b].data[0][0]
                        self.assertEqual(symbol, symbol_topk)
                        self.assertTrue(torch.equal(t_output.data, output_topk[t_step].data))
                if sum(finished) == batch_size:
                    break
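
For reference, a minimal usage sketch (not part of the original file): it wraps the same DecoderRNN in a TopKDecoder with a beam width larger than one and calls it without inputs, just as test_k_1 does for k=1; the hyperparameter values and variable names are illustrative.

    from seq2seq.models import DecoderRNN, TopKDecoder

    # Build a decoder as in the test above and wrap it in a k=3 beam-search decoder.
    decoder = DecoderRNN(100, 50, 16, 0, 1)  # vocab_size, max_len, hidden_size, sos_id, eos_id
    beam_decoder = TopKDecoder(decoder, 3)

    # Decode without inputs, the same call pattern used for k=1 in the test above.
    outputs, _, meta = beam_decoder()

    # meta exposes the fields the test reads: per-step top-k symbols, scores and lengths.
    topk_symbols = meta['topk_sequence']
    final_scores = meta['score']
    lengths = meta['length']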