def test_k_1(self):
    """When k=1, the output of topk decoder should be the same as a normal decoder.

    The check is repeated with 10 freshly built decoders whose parameters are
    re-drawn uniformly from [-1, 1]. For every decoding step of the plain
    decoder, and for each batch entry until it emits ``eos``, the test checks:

    * the step output tensor equals the top-k decoder's output for that step;
    * the greedy symbol in ``other['sequence']`` equals the single beam's
      symbol in ``other_topk['topk_sequence']``.

    At the step where an entry first emits ``eos`` it also checks that the
    top-k decoder reports that sequence length and a score close to the sum
    of the plain decoder's per-step top-1 values.
    """
    batch_size = 1
    eos = 1
    for _ in range(10):
        # Repeat the randomized test multiple times. No seed is fixed here, so
        # each iteration draws different weights.
        # NOTE(review): the meaning of DecoderRNN's positional arguments is
        # defined elsewhere; ``self.vocab_size`` is presumably set in setUp.
        decoder = DecoderRNN(self.vocab_size, 50, 16, 0, eos)
        for param in decoder.parameters():
            param.data.uniform_(-1, 1)
        topk_decoder = TopKDecoder(decoder, 1)

        output, _, other = decoder()
        output_topk, _, other_topk = topk_decoder()
        self.assertEqual(len(output), len(output_topk))

        finished = [False] * batch_size
        seq_scores = [0] * batch_size
        for t_step, t_output in enumerate(output):
            # Top-1 value at this step; accumulated to rebuild the sequence score.
            score, _ = t_output.topk(1)
            symbols = other['sequence'][t_step]
            for b in range(batch_size):
                seq_scores[b] += score[b].data[0]
                symbol = symbols[b].data[0]
                if not finished[b] and symbol == eos:
                    # First eos for entry b: the reported length counts every
                    # step up to and including this one (t_step is 0-based).
                    finished[b] = True
                    self.assertEqual(other_topk['length'][b], t_step + 1)
                    self.assertTrue(np.isclose(seq_scores[b], other_topk['score'][b][0]))
                if not finished[b]:
                    # Only compare steps before eos; later steps are not
                    # checked against the top-k decoder.
                    symbol_topk = other_topk['topk_sequence'][t_step][b].data[0][0]
                    self.assertEqual(symbol, symbol_topk)
                    self.assertTrue(torch.equal(t_output.data, output_topk[t_step].data))
            # Stop decoding once every batch entry has emitted eos.
            if all(finished):
                break
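# (Removed web-page navigation residue that followed this snippet:
#  "comment list" and "table of contents" headings.)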