util_test.py source file

python

Project: allennlp  Author: allenai
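# Imports the snippet relies on. The scraped excerpt omits them, so the
# module paths below are assumed.
import numpy
import torch
from torch.autograd import Variable

from allennlp.nn import util


def test_weighted_sum_handles_uneven_higher_order_input(self):
    # Takes `self` because it is written as a method of a test class.
    batch_size = 1
    length_1 = 5
    length_2 = 6
    length_3 = 2
    embedding_dim = 4
    # One embedding per item: (batch_size, length_3, embedding_dim).
    sentence_array = numpy.random.rand(batch_size, length_3, embedding_dim)
    # Attention has two extra leading dimensions; its last axis indexes the items.
    attention_array = numpy.random.rand(batch_size, length_1, length_2, length_3)
    sentence_tensor = Variable(torch.from_numpy(sentence_array).float())
    attention_tensor = Variable(torch.from_numpy(attention_array).float())
    aggregated_array = util.weighted_sum(sentence_tensor, attention_tensor).data.numpy()
    assert aggregated_array.shape == (batch_size, length_1, length_2, embedding_dim)
    # Each (i, j) cell must equal the attention-weighted sum of the two embeddings.
    for i in range(length_1):
        for j in range(length_2):
            expected_array = (attention_array[0, i, j, 0] * sentence_array[0, 0] +
                              attention_array[0, i, j, 1] * sentence_array[0, 1])
            numpy.testing.assert_almost_equal(aggregated_array[0, i, j], expected_array,
                                              decimal=5)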
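
The behaviour this test checks can be reproduced in plain NumPy. The sketch below is only an illustration of the expected semantics, not allennlp's implementation: `weighted_sum_reference` is a hypothetical helper that assumes the matrix has shape (batch, num_items, dim) and the attention has shape (batch, ..., num_items).

```python
import numpy


def weighted_sum_reference(matrix, attention):
    """Hypothetical NumPy stand-in for the semantics the test expects.

    matrix:    (batch, num_items, dim)
    attention: (batch, ..., num_items)
    returns:   (batch, ..., dim)
    """
    batch, num_items, dim = matrix.shape
    # Flatten the extra attention dimensions so a batched matmul applies.
    flat_attention = attention.reshape(batch, -1, num_items)
    # (batch, n, num_items) @ (batch, num_items, dim) -> (batch, n, dim)
    flat_result = numpy.matmul(flat_attention, matrix)
    # Restore the leading attention dimensions.
    return flat_result.reshape(attention.shape[:-1] + (dim,))


sentence_array = numpy.random.rand(1, 2, 4)
attention_array = numpy.random.rand(1, 5, 6, 2)
aggregated = weighted_sum_reference(sentence_array, attention_array)
print(aggregated.shape)  # (1, 5, 6, 4)
```

Flattening the middle dimensions keeps the sketch to a single batched matrix product. The output shape matches the one asserted in the test: the attention shape with its last axis replaced by the embedding dimension.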