def test_add_sentence_boundary_token_ids_handles_3D_input(self):
    """Check ``util.add_sentence_boundary_token_ids`` on a 3D input of shape (2, 3, 4).

    The input holds two sequences of three 4-element timesteps. The second
    sequence's last timestep is all zeros, so the mask built below marks it as
    padding. The expected output has a row of ``bos`` ids prepended to every
    sequence. It also has a row of ``eos`` ids right after each sequence's last
    unmasked timestep, and the padding row is still zero. The returned mask is
    therefore two positions longer than the input mask.
    """
    tensor = Variable(torch.from_numpy(
            numpy.array([[[1, 2, 3, 4],
                          [5, 5, 5, 5],
                          [6, 8, 1, 2]],
                         [[4, 3, 2, 1],
                          [8, 7, 6, 5],
                          [0, 0, 0, 0]]])))
    # A timestep counts as unmasked when any of its elements is positive, so
    # the all-zero final row of the second sequence is treated as padding.
    mask = ((tensor > 0).sum(dim=-1) > 0).type(torch.LongTensor)
    bos = Variable(torch.from_numpy(numpy.array([9, 9, 9, 9])))
    eos = Variable(torch.from_numpy(numpy.array([10, 10, 10, 10])))

    new_tensor, new_mask = util.add_sentence_boundary_token_ids(tensor, mask, bos, eos)

    expected_new_tensor = numpy.array([[[9, 9, 9, 9],
                                        [1, 2, 3, 4],
                                        [5, 5, 5, 5],
                                        [6, 8, 1, 2],
                                        [10, 10, 10, 10]],
                                       [[9, 9, 9, 9],
                                        [4, 3, 2, 1],
                                        [8, 7, 6, 5],
                                        [10, 10, 10, 10],
                                        [0, 0, 0, 0]]])
    # Spelled out explicitly instead of being derived from expected_new_tensor,
    # so the mask is checked independently of the tensor values.
    expected_new_mask = numpy.array([[1, 1, 1, 1, 1],
                                     [1, 1, 1, 1, 0]])

    # assert_array_equal also verifies the array shapes. The `(a == b).all()`
    # idiom broadcasts, so a wrongly shaped result could slip through.
    numpy.testing.assert_array_equal(new_tensor.data.numpy(), expected_new_tensor)
    numpy.testing.assert_array_equal(new_mask.data.numpy(), expected_new_mask)
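# Page-navigation residue from the scraped source ("Comment list", "Table of contents"); not part of the test.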