def test_weighted_sum_handles_higher_order_input(self):
    """Check ``util.weighted_sum`` on a rank-5 input with rank-4 attention weights.

    ``sentence_array`` has shape (batch, length_1, length_2, length_3, embedding_dim)
    and ``attention_array`` has shape (batch, length_1, length_2, length_3). The
    result is expected to have the ``length_3`` axis summed out, giving shape
    (batch, length_1, length_2, embedding_dim), where each cell is the
    attention-weighted sum of the corresponding embedding vectors.
    """
    batch_size = 1
    length_1 = 5
    length_2 = 6
    length_3 = 2
    embedding_dim = 4
    sentence_array = numpy.random.rand(batch_size, length_1, length_2, length_3, embedding_dim)
    attention_array = numpy.random.rand(batch_size, length_1, length_2, length_3)
    sentence_tensor = Variable(torch.from_numpy(sentence_array).float())
    attention_tensor = Variable(torch.from_numpy(attention_array).float())
    aggregated_array = util.weighted_sum(sentence_tensor, attention_tensor).data.numpy()
    assert aggregated_array.shape == (batch_size, length_1, length_2, embedding_dim)

    # Hand-written check of a single cell: with length_3 == 2 the weighted sum at
    # [0, 3, 2] is just the two attention-weighted embedding vectors added together.
    expected_cell = (attention_array[0, 3, 2, 0] * sentence_array[0, 3, 2, 0] +
                     attention_array[0, 3, 2, 1] * sentence_array[0, 3, 2, 1])
    numpy.testing.assert_almost_equal(aggregated_array[0, 3, 2], expected_cell, decimal=5)

    # Reference result for every cell, so a mistake at any other position (not just
    # [0, 3, 2]) also fails the test: broadcast the weights across the embedding
    # axis, then sum out the length_3 axis (second-to-last).
    expected_array = (attention_array[..., numpy.newaxis] * sentence_array).sum(axis=-2)
    numpy.testing.assert_almost_equal(aggregated_array, expected_array, decimal=5)