def test_input_dropout_WITH_NON_ZERO_PROB(self):
    """Check that a non-zero input dropout makes repeated forward passes differ.

    Builds a ``DecoderRNN`` with ``input_dropout_p=0.5``, fills every
    parameter with uniform values in [-1, 1], then calls the decoder twice
    per attempt with the same (default) arguments. With dropout active, at
    least one of 50 attempts is expected to yield a different element 0 of
    the returned outputs; the test fails if every pair is identical.
    """
    rnn = DecoderRNN(self.vocab_size, 50, 16, 0, 1, input_dropout_p=0.5)
    for param in rnn.parameters():
        param.data.uniform_(-1, 1)
    # Dropout is only applied in training mode. nn.Module starts in that
    # mode, but set it explicitly so the test does not hinge on the default.
    rnn.train()

    def outputs_differ():
        # Only the first returned value (the outputs) is compared; the other
        # two values from the 3-tuple are ignored.
        first, _hidden, _extra = rnn()
        second, _hidden, _extra = rnn()
        return not torch.equal(first[0], second[0])

    # Gradients are never used here, so skip building autograd graphs.
    with torch.no_grad():
        # any() stops at the first differing pair, matching the original
        # flag-and-break loop.
        differs = any(outputs_differ() for _attempt in range(50))

    self.assertTrue(
        differs,
        "input dropout (p=0.5) never changed the output in 50 attempts",
    )
评论列表
文章目录