pytorch_seq2seq_wrapper_test.py source code


Project: allennlp | Author: allenai | Language: Python
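This test, excerpted from allennlp's test suite, exercises the stateful mode of PytorchSeq2SeqWrapper: the encoder is called three times with different batch sizes and sequence lengths, and the assertions check both that masked positions produce zero output and that hidden state is carried over correctly between calls.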
import numpy
import torch
from numpy.testing import assert_almost_equal
from torch.autograd import Variable
from torch.nn import LSTM

from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper

# Excerpted from a test class; `self` is the test-case instance.
def test_wrapper_stateful(self):
        lstm = LSTM(bidirectional=True, num_layers=2, input_size=3, hidden_size=7, batch_first=True)
        encoder = PytorchSeq2SeqWrapper(lstm, stateful=True)

        # To test the stateful functionality we need to call the encoder multiple times.
        # Different batch sizes further test some of the logic.
        batch_sizes = [5, 10, 8]
        sequence_lengths = [4, 6, 7]
        states = []
        for batch_size, sequence_length in zip(batch_sizes, sequence_lengths):
            tensor = Variable(torch.rand([batch_size, sequence_length, 3]))
            mask = Variable(torch.ones(batch_size, sequence_length))
            mask.data[0, 3:] = 0
            encoder_output = encoder(tensor, mask)
            states.append(encoder._states)  # pylint: disable=protected-access

        # Check that the output is masked properly.
        assert_almost_equal(encoder_output[0, 3:, :].data.numpy(), numpy.zeros((4, 14)))

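        # Check that state is carried over between calls: the third call used a
        # smaller batch (8) than the second (10), so the stored states for the
        # last two rows should be exactly those left over from the second call.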
        for k in range(2):
            assert_almost_equal(
                    states[-1][k][:, -2:, :].data.numpy(), states[-2][k][:, -2:, :].data.numpy()
            )
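
As a rough sketch, the "stateful" behavior exercised above amounts to feeding an LSTM's final (hidden, cell) state back in as the initial state of the next forward call. The plain-PyTorch snippet below illustrates just that mechanism; it is an illustration, not AllenNLP's implementation, since it keeps the batch size fixed and skips the sorting, masking, and shrinking-batch bookkeeping that PytorchSeq2SeqWrapper adds (which is exactly what the assertions above verify):

import torch
from torch.nn import LSTM

# Same configuration as the test: 2 layers, bidirectional, hidden size 7.
lstm = LSTM(input_size=3, hidden_size=7, num_layers=2,
            bidirectional=True, batch_first=True)

state = None  # no state before the first call; the LSTM then starts from zeros
for _ in range(3):
    batch = torch.rand(5, 4, 3)          # (batch, sequence, features)
    output, state = lstm(batch, state)   # carry (h_n, c_n) into the next call

# state is the tuple (h_n, c_n), each of shape
# (num_layers * num_directions, batch, hidden_size) = (4, 5, 7).
print(state[0].shape, state[1].shape)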