pytorch_seq2seq_wrapper_test.py source code

python

Project: allennlp  Author: allenai
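import numpy
import torch
from numpy.testing import assert_almost_equal
from torch.autograd import Variable
from torch.nn import GRU

from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper


# Excerpt: a method of the wrapper's test class, hence the `self` parameter.
def test_wrapper_stateful_single_state_gru(self):
    # Bidirectional, so the output feature size is 2 * hidden_size = 14.
    gru = GRU(bidirectional=True, num_layers=2, input_size=3, hidden_size=7, batch_first=True)
    encoder = PytorchSeq2SeqWrapper(gru, stateful=True)

    # Run a batch of 10, then a smaller batch of 5, recording the stored state after each call.
    batch_sizes = [10, 5]
    states = []
    for batch_size in batch_sizes:
        tensor = Variable(torch.rand([batch_size, 5, 3]))
        mask = Variable(torch.ones(batch_size, 5))
        # Mask out the last two timesteps of the first sequence in the batch.
        mask.data[0, 3:] = 0
        encoder_output = encoder(tensor, mask)
        states.append(encoder._states)   # pylint: disable=protected-access

    # The masked timesteps of the first sequence must be all zeros: 2 timesteps x 14 features.
    assert_almost_equal(encoder_output[0, 3:, :].data.numpy(), numpy.zeros((2, 14)))
    # A GRU has a single state tensor (index 0). Its last 5 batch rows must be
    # identical after the batch of 5 and after the batch of 10.
    assert_almost_equal(
        states[-1][0][:, -5:, :].data.numpy(), states[-2][0][:, -5:, :].data.numpy()
    )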
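
The final assertion in the test checks that after a smaller batch, the trailing rows of the stored state are unchanged. A minimal pure-Python sketch of that idea follows. The helper `update_states` and the list-of-rows state layout are hypothetical simplifications for illustration, not AllenNLP's implementation: the sketch ignores the layer dimension, sequence sorting, and masking of empty sequences.

```python
from typing import List

# Hypothetical layout: one row of hidden values per batch element.
State = List[List[float]]


def update_states(stored: State, new: State) -> State:
    """Hypothetical helper: overwrite the leading rows of `stored` with `new`,
    keeping any trailing rows left over from a larger earlier batch."""
    if len(new) >= len(stored):
        return [row[:] for row in new]
    return [row[:] for row in new] + [row[:] for row in stored[len(new):]]


first = [[float(i)] * 2 for i in range(10)]   # state after a batch of 10
second = [[-1.0] * 2 for _ in range(5)]       # state after a batch of 5
merged = update_states(first, second)

# Mirrors the test: the last 5 rows are the same before and after the smaller batch.
assert merged[-5:] == first[-5:]
```

Under this simplified model, only the first 5 rows change when the batch of 5 is processed, which is the behaviour the slice `[:, -5:, :]` compares in the test.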