test_encoder_rnn.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:pytorch-seq2seq 作者: IBM 项目源码 文件源码
def test_input_dropout_WITH_NON_ZERO_PROB(self):
        rnn = EncoderRNN(self.vocab_size, 50, 16, input_dropout_p=0.5)
        for param in rnn.parameters():
            param.data.uniform_(-1, 1)

        equal = True
        for _ in range(50):
            output1, _ = rnn(self.input_var, self.lengths)
            output2, _ = rnn(self.input_var, self.lengths)
            if not torch.equal(output1.data, output2.data):
                equal = False
                break
        self.assertFalse(equal)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号