def test_input_dropout_WITH_NON_ZERO_PROB(self):
    """Input dropout with p > 0 should make repeated forward passes differ.

    Builds an encoder with ``input_dropout_p=0.5`` and runs the same batch
    through it twice, up to 50 times. The test passes as soon as one pair
    of outputs is not identical, and fails only if all 50 pairs match.
    """
    rnn = EncoderRNN(self.vocab_size, 50, 16, input_dropout_p=0.5)
    # PyTorch dropout is only active in training mode. Set it explicitly
    # instead of relying on the module's default mode.
    # NOTE(review): assumes EncoderRNN is a torch.nn.Module whose input
    # dropout honours train()/eval().
    rnn.train()
    # Re-initialise every parameter uniformly in [-1, 1] (presumably so the
    # effect of dropped inputs on the output is not masked by tiny weights).
    for param in rnn.parameters():
        param.data.uniform_(-1, 1)

    def _outputs_differ():
        # Two forward passes over the same input; only the outputs are
        # compared, so the returned hidden states are discarded.
        first, _hidden = rnn(self.input_var, self.lengths)
        second, _hidden = rnn(self.input_var, self.lengths)
        return not torch.equal(first.data, second.data)

    # any() stops at the first differing pair, mirroring the original
    # break-on-difference loop.
    differed = any(_outputs_differ() for _attempt in range(50))
    self.assertTrue(
        differed,
        "outputs were identical across 50 forward-pass pairs "
        "despite input_dropout_p=0.5",
    )
评论列表
文章目录