def test_weighted_sum_handles_uneven_higher_order_input(self):
    """Check ``util.weighted_sum`` with an attention tensor of higher order than the sentence.

    The sentence tensor has shape ``(batch_size, length_3, embedding_dim)`` and
    the attention tensor has shape ``(batch_size, length_1, length_2, length_3)``.
    The result is required to have shape
    ``(batch_size, length_1, length_2, embedding_dim)``, and every
    ``[b, i, j]`` entry must equal the sentence rows weighted by
    ``attention[b, i, j]`` and summed over the ``length_3`` axis.
    """
    batch_size = 1
    length_1 = 5
    length_2 = 6
    length_3 = 2
    embedding_dim = 4

    sentence_array = numpy.random.rand(batch_size, length_3, embedding_dim)
    attention_array = numpy.random.rand(batch_size, length_1, length_2, length_3)
    sentence_tensor = Variable(torch.from_numpy(sentence_array).float())
    attention_tensor = Variable(torch.from_numpy(attention_array).float())

    aggregated_array = util.weighted_sum(sentence_tensor, attention_tensor).data.numpy()

    assert aggregated_array.shape == (batch_size, length_1, length_2, embedding_dim)

    # Build the reference from the dimension constants rather than a
    # hand-expanded two-term sum, so the check stays valid if batch_size or
    # length_3 are changed above.
    for b in range(batch_size):
        for i in range(length_1):
            for j in range(length_2):
                # (length_3,) . (length_3, embedding_dim) -> (embedding_dim,)
                expected_array = numpy.dot(attention_array[b, i, j], sentence_array[b])
                # The tensors were cast with .float() while the reference is
                # computed from the original NumPy arrays, so compare to
                # 5 decimals only.
                numpy.testing.assert_almost_equal(aggregated_array[b, i, j],
                                                  expected_array,
                                                  decimal=5)
# (Web-page residue, not part of the test: "Comment list" / "Article table of contents".)