pytorch_seq2seq_wrapper_test.py — source code

Project: allennlp · Author: allenai · Language: Python
import torch
from torch.autograd import Variable
from torch.nn import LSTM

from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper


def test_wrapper_works_when_passed_state_with_zero_length_sequences(self):
    lstm = LSTM(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
    encoder = PytorchSeq2SeqWrapper(lstm)
    tensor = torch.rand([5, 7, 3])
    mask = torch.ones(5, 7)
    # Zero out the padded positions in each row; row 2 is masked entirely,
    # giving it a sequence length of zero.
    mask[0, 3:] = 0
    mask[1, 4:] = 0
    mask[2, 0:] = 0
    mask[3, 6:] = 0

    # Initial states are of shape (num_layers * num_directions, batch_size, hidden_dim)
    initial_states = (Variable(torch.randn(6, 5, 7)),
                      Variable(torch.randn(6, 5, 7)))

    input_tensor = Variable(tensor)
    mask = Variable(mask)
    _ = encoder(input_tensor, mask, initial_states)
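
For context, the mask marks valid timesteps per batch row, so each sequence's length falls out of a sum over the time dimension; row 2, being fully masked, has length zero, which is exactly the edge case the test name refers to. A minimal sketch in plain PyTorch (illustrative only, not allennlp's internal implementation):

import torch

mask = torch.ones(5, 7)
mask[0, 3:] = 0
mask[1, 4:] = 0
mask[2, 0:] = 0
mask[3, 6:] = 0

# Each sequence's valid length is its count of unmasked timesteps.
lengths = mask.sum(dim=1).long()
print(lengths.tolist())  # [3, 4, 0, 6, 7]

Since the wrapped LSTM is bidirectional with hidden_size=7, the encoder's output for this input should have shape (5, 7, 14): the two directions' hidden states concatenated along the last dimension.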