def test_remove_sentence_boundaries(self):
    """Check ``util.remove_sentence_boundaries`` against a hand-built expectation.

    A random (3, 5, 7) tensor is paired with a mask whose rows have 2, 5 and 4
    unmasked positions. The expected result drops the first and last unmasked
    timestep of every row, so the time dimension shrinks from 5 to 3. The rows
    then keep 0, 3 and 2 positions respectively, with zeros elsewhere.
    """
    batch_size, timesteps, dim = 3, 5, 7
    inputs = Variable(torch.from_numpy(numpy.random.rand(batch_size, timesteps, dim)))

    # The mask with two elements (row 0) is the corner case of an empty
    # sequence: removing the boundaries from "<S> </S>" leaves nothing.
    raw_mask = numpy.array([[1, 1, 0, 0, 0],
                            [1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 0]])
    mask = Variable(torch.from_numpy(raw_mask)).long()

    stripped, stripped_mask = util.remove_sentence_boundaries(inputs, mask)

    # Expected tensor: the interior of each row, shifted left by one position.
    # Row 0 stays all zeros because its sequence becomes empty.
    expected = Variable(torch.zeros(batch_size, timesteps - 2, dim))
    expected[1, 0:3, :] = inputs[1, 1:4, :]
    expected[2, 0:2, :] = inputs[2, 1:3, :]
    assert_array_almost_equal(stripped.data.numpy(), expected.data.numpy())

    raw_expected_mask = numpy.array([[0, 0, 0],
                                     [1, 1, 1],
                                     [1, 1, 0]])
    expected_mask = Variable(torch.from_numpy(raw_expected_mask)).long()
    assert (stripped_mask.data.numpy() == expected_mask.data.numpy()).all()
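# NOTE(review): the source page's navigation labels ("Comment list",
# "Table of contents") were captured here as scrape residue; they are not code.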