def test_no_mask(self):
    """Check the attention weights produced when no mask is passed.

    Two cases are exercised, each with a leading dimension of size 1:

    * a generic vector/matrix pair, where the expected weights equal the
      softmax of the dot products between the vector and each matrix row
      (0.31 and 0.195) -- presumably the default similarity of
      ``Attention`` is a dot product; confirm against its definition;
    * all-zero inputs, where the two rows must receive equal weight.
    """
    attention = Attention()

    # Testing general non-batched case.
    vector = Variable(torch.FloatTensor([[0.3, 0.1, 0.5]]))
    matrix = Variable(torch.FloatTensor([[[0.6, 0.8, 0.1], [0.15, 0.5, 0.2]]]))
    result = attention(vector, matrix).data.numpy()
    assert_almost_equal(result, numpy.array([[0.52871835, 0.47128162]]))

    # Testing non-batched case where inputs are all 0s: neither row can be
    # preferred, so the weights are expected to be split evenly.
    vector = Variable(torch.FloatTensor([[0, 0, 0]]))
    matrix = Variable(torch.FloatTensor([[[0, 0, 0], [0, 0, 0]]]))
    result = attention(vector, matrix).data.numpy()
    assert_almost_equal(result, numpy.array([[0.5, 0.5]]))
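# Web-page navigation residue, not part of the test: "评论列表" (comment list), "文章目录" (article table of contents).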