import numpy
from numpy.testing import assert_almost_equal
import torch
from torch.autograd import Variable
from torch.nn import LSTM
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper


def test_wrapper_stateful(self):
    lstm = LSTM(bidirectional=True, num_layers=2, input_size=3, hidden_size=7, batch_first=True)
    encoder = PytorchSeq2SeqWrapper(lstm, stateful=True)

    # To test the stateful functionality we need to call the encoder multiple times.
    # Different batch sizes further test some of the logic.
    batch_sizes = [5, 10, 8]
    sequence_lengths = [4, 6, 7]
    states = []
    for batch_size, sequence_length in zip(batch_sizes, sequence_lengths):
        tensor = Variable(torch.rand([batch_size, sequence_length, 3]))
        mask = Variable(torch.ones(batch_size, sequence_length))
        # Mask out everything after the third timestep of the first example.
        mask.data[0, 3:] = 0
        encoder_output = encoder(tensor, mask)
        states.append(encoder._states)  # pylint: disable=protected-access

    # Check that the output is masked properly: the masked positions of the first
    # example (4 timesteps x 2 * hidden_size = 14 dims) should be all zeros.
    assert_almost_equal(encoder_output[0, 3:, :].data.numpy(), numpy.zeros((4, 14)))

    # The final batch (size 8) only updates the first 8 rows of the cached
    # size-10 state, so the last two rows of both the hidden and cell states
    # must be carried over unchanged from the previous call.
    for k in range(2):
        assert_almost_equal(
            states[-1][k][:, -2:, :].data.numpy(),
            states[-2][k][:, -2:, :].data.numpy()
        )