def test_batched_masked(self):
    attention = Attention()
    # Testing general masked batched case.
    vector = Variable(torch.FloatTensor([[0.3, 0.1, 0.5], [0.3, 0.1, 0.5]]))
    matrix = Variable(torch.FloatTensor([[[0.6, 0.8, 0.1], [0.15, 0.5, 0.2], [0.5, 0.3, 0.2]],
                                         [[0.6, 0.8, 0.1], [0.15, 0.5, 0.2], [0.5, 0.3, 0.2]]]))
    mask = Variable(torch.FloatTensor([[1.0, 1.0, 0.0], [1.0, 0.0, 1.0]]))
    result = attention(vector, matrix, mask).data.numpy()
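    # Positions where the mask is 0 receive zero attention weight; the weights over
    # the unmasked positions are renormalized so each batch row still sums to 1.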
    assert_almost_equal(result, numpy.array([[0.52871835, 0.47128162, 0.0],
                                             [0.50749944, 0.0, 0.49250056]]))
    # Test the case where one batch element's input vector is all 0s and another's mask is all 0s.
    vector = Variable(torch.FloatTensor([[0.0, 0.0, 0.0], [0.3, 0.1, 0.5]]))
    matrix = Variable(torch.FloatTensor([[[0.6, 0.8, 0.1], [0.15, 0.5, 0.2], [0.5, 0.3, 0.2]],
                                         [[0.6, 0.8, 0.1], [0.15, 0.5, 0.2], [0.5, 0.3, 0.2]]]))
    mask = Variable(torch.FloatTensor([[1.0, 1.0, 0.0], [0.0, 0.0, 0.0]]))
    result = attention(vector, matrix, mask).data.numpy()
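    # An all-zero query vector gives a uniform distribution over the unmasked positions,
    # while an all-zero mask yields an all-zero distribution rather than NaNs.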
    assert_almost_equal(result, numpy.array([[0.5, 0.5, 0.0],
                                             [0.0, 0.0, 0.0]]))
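
# A minimal reference sketch (not necessarily the Attention module's actual
# implementation) of the masked dot-product softmax that the expected values above
# assume. The helper name `reference_masked_attention` is illustrative only, and
# `numpy` is assumed to be imported as it is elsewhere in this file.
def reference_masked_attention(vector, matrix, mask):
    # vector: (batch, dim), matrix: (batch, rows, dim), mask: (batch, rows) of 0s/1s.
    # Raw similarity scores: dot product of the query vector with each matrix row.
    scores = (matrix * vector[:, None, :]).sum(axis=-1)
    # Softmax restricted to the unmasked positions (masked exponentials are zeroed).
    exps = numpy.exp(scores - scores.max(axis=-1, keepdims=True)) * mask
    totals = exps.sum(axis=-1, keepdims=True)
    # Rows whose mask is entirely zero get an all-zero distribution instead of NaNs.
    return numpy.where(totals > 0, exps / numpy.where(totals > 0, totals, 1.0), 0.0)
    # Feeding the numpy versions of the inputs above through this helper should
    # reproduce the expected arrays asserted in the test, up to the checked precision.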