def test_wrapper_stateful_single_state_gru(self):
    """Check masking and state carry-over of a stateful, bidirectional GRU wrapper.

    A 2-layer bidirectional GRU (input_size=3, hidden_size=7, batch_first) is
    wrapped with ``stateful=True``. Two batches are fed through it, first 10 rows
    and then 5 rows, each made of length-5 sequences with 3 features. In every
    batch only the first sequence is masked from timestep 3 onward.
    """
    seq_len = 5
    rnn = GRU(bidirectional=True, num_layers=2, input_size=3, hidden_size=7, batch_first=True)
    wrapper = PytorchSeq2SeqWrapper(rnn, stateful=True)

    snapshots = []
    for rows in (10, 5):
        inputs = Variable(torch.rand([rows, seq_len, 3]))
        seq_mask = Variable(torch.ones(rows, seq_len))
        # Mask timesteps 3.. of the first sequence only; all others stay fully valid.
        seq_mask.data[0, 3:] = 0
        output = wrapper(inputs, seq_mask)
        # Record whatever state object the wrapper holds after this call.
        snapshots.append(wrapper._states)  # pylint: disable=protected-access

    # `output` is from the last (5-row) batch. The masked steps of its first
    # sequence must come out as zeros. 14 features = 2 directions * hidden_size 7.
    assert_almost_equal(output[0, 3:, :].data.numpy(), numpy.zeros((2, 14)))

    # NOTE(review): presumably axis 1 of the stored state is the batch axis.
    # In that case this checks that the trailing five rows of the first state
    # tensor are unchanged by the smaller second batch. Confirm against the
    # wrapper's `_states` layout.
    before, after = snapshots[-2], snapshots[-1]
    assert_almost_equal(
        after[0][:, -5:, :].data.numpy(), before[0][:, -5:, :].data.numpy()
    )
# Web-page residue, not code: "评论列表" ("comment list"), "文章目录" ("table of contents")