pytorch_seq2seq_wrapper_test.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:allennlp 作者: allenai 项目源码 文件源码
def test_forward_works_even_with_empty_sequences(self):
        lstm = LSTM(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
        encoder = PytorchSeq2SeqWrapper(lstm)

        tensor = torch.autograd.Variable(torch.rand([5, 7, 3]))
        tensor[1, 6:, :] = 0
        tensor[2, :, :] = 0
        tensor[3, 2:, :] = 0
        tensor[4, :, :] = 0
        mask = torch.autograd.Variable(torch.ones(5, 7))
        mask[1, 6:] = 0
        mask[2, :] = 0
        mask[3, 2:] = 0
        mask[4, :] = 0

        results = encoder(tensor, mask)

        for i in (0, 1, 3):
            assert not (results[i] == 0.).data.all()
        for i in (2, 4):
            assert (results[i] == 0.).data.all()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号