def test_wrapper_works_when_passed_state_with_zero_length_sequences(self):
    # Regression check: the wrapper must not crash when the batch contains a
    # sequence of length zero and explicit initial states are supplied.
    lstm = LSTM(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
    encoder = PytorchSeq2SeqWrapper(lstm)
    input_tensor = Variable(torch.rand([5, 7, 3]))
    # Padding mask over (batch=5, timesteps=7); row 2 is entirely padded,
    # i.e. a zero-length sequence, and row 4 is left fully unmasked.
    mask = torch.ones(5, 7)
    for row, valid_length in ((0, 3), (1, 4), (2, 0), (3, 6)):
        mask[row, valid_length:] = 0
    mask = Variable(mask)
    # Initial states are of shape (num_layers * num_directions, batch_size, hidden_dim)
    initial_states = (Variable(torch.randn(6, 5, 7)),
                      Variable(torch.randn(6, 5, 7)))
    # Success criterion is simply that this forward pass raises no error.
    _ = encoder(input_tensor, mask, initial_states)