def test_forward_pulls_out_correct_tensor_with_sequence_lengths(self):
    """Check the wrapper matches a manual pack -> LSTM -> pad round trip.

    Builds a batch of 5 variable-length sequences (lengths 7, 6, 4, 2, 1),
    runs the raw bidirectional LSTM on a hand-packed version of the input,
    and asserts the wrapper produces the same padded output when given the
    unpacked tensor plus its binary mask.
    """
    lstm = LSTM(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
    encoder = PytorchSeq2SeqWrapper(lstm)

    batch = torch.rand([5, 7, 3])
    mask = torch.ones(5, 7)
    # Zero out the padded tail of each sequence in both the data and the mask.
    # Row 0 keeps its full length of 7 (the `7:` slice is empty, a no-op).
    for row, length in enumerate([7, 6, 4, 2, 1]):
        batch[row, length:, :] = 0
        mask[row, length:] = 0

    input_tensor = Variable(batch)
    mask = Variable(mask)

    # Reference path: pack the padded batch ourselves and call the LSTM directly.
    sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
    packed_sequence = pack_padded_sequence(input_tensor, sequence_lengths.data.tolist(), batch_first=True)
    reference_output, _ = lstm(packed_sequence)
    reference_tensor, _ = pad_packed_sequence(reference_output, batch_first=True)

    # Path under test: the wrapper handles packing/unpacking from the mask.
    encoder_output = encoder(input_tensor, mask)
    assert_almost_equal(encoder_output.data.numpy(), reference_tensor.data.numpy())