def test_forward_works_even_with_empty_sequences(self):
    """Encode a batch in which some sequences are entirely masked out.

    Rows 2 and 4 are fully masked, so their encodings are expected to be all
    zeros. Rows 0, 1 and 3 keep at least one unmasked timestep, so each of
    their encodings must contain at least one non-zero value.
    """
    lstm = LSTM(
        bidirectional=True,
        num_layers=3,
        input_size=3,
        hidden_size=11,
        batch_first=True,
    )
    encoder = PytorchSeq2VecWrapper(lstm)

    # batch_first=True: 5 sequences x 7 timesteps x 3 input features.
    tensor = torch.autograd.Variable(torch.rand([5, 7, 3]))
    mask = torch.autograd.Variable(torch.ones(5, 7))

    # (row, number of leading unmasked timesteps). Row 0 keeps all 7
    # timesteps and is left untouched; a length of 0 empties the row.
    unmasked_lengths = ((1, 6), (2, 0), (3, 2), (4, 0))
    for row, length in unmasked_lengths:
        # Zero the padded positions in both the input and the mask.
        tensor[row, length:, :] = 0
        mask[row, length:] = 0

    results = encoder(tensor, mask)

    # Sequences with real timesteps must not encode to all zeros ...
    for row in (0, 1, 3):
        assert not (results[row] == 0.).data.all()
    # ... whereas fully-masked sequences must encode to all zeros.
    for row in (2, 4):
        assert (results[row] == 0.).data.all()
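# NOTE(review): stray web-page navigation labels left over from scraping, not code:
# "评论列表" ("comment list"), "文章目录" ("table of contents").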