attention_test.py — source code

Project: allennlp  Author: allenai  Language: Python
import numpy
import torch
from numpy.testing import assert_almost_equal
# Legacy pre-0.4 PyTorch API; in modern PyTorch, plain tensors replace Variable.
from torch.autograd import Variable
from allennlp.modules.attention import Attention

def test_no_mask(self):
        attention = Attention()

        # Testing general non-batched case.
        vector = Variable(torch.FloatTensor([[0.3, 0.1, 0.5]]))
        matrix = Variable(torch.FloatTensor([[[0.6, 0.8, 0.1], [0.15, 0.5, 0.2]]]))

        result = attention(vector, matrix).data.numpy()
        assert_almost_equal(result, numpy.array([[0.52871835, 0.47128162]]))

        # Testing non-batched case where inputs are all 0s.
        vector = Variable(torch.FloatTensor([[0, 0, 0]]))
        matrix = Variable(torch.FloatTensor([[[0, 0, 0], [0, 0, 0]]]))

        result = attention(vector, matrix).data.numpy()
        assert_almost_equal(result, numpy.array([[0.5, 0.5]]))
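The first expected value can be checked by hand: with no mask, `Attention` here computes a softmax over the dot products between the query vector and each matrix row (0.31 and 0.195 for the inputs above). A minimal numpy sketch of that computation, independent of AllenNLP (the function name `dot_product_attention` is ours, not the library's):

```python
import numpy as np

def dot_product_attention(vector, matrix):
    # Dot product of the query vector with each row, then softmax.
    scores = matrix @ vector                 # shape: (num_rows,)
    scores = scores - scores.max()           # subtract max for numerical stability
    exps = np.exp(scores)
    return exps / exps.sum()

weights = dot_product_attention(
    np.array([0.3, 0.1, 0.5]),
    np.array([[0.6, 0.8, 0.1], [0.15, 0.5, 0.2]]),
)
# weights ≈ [0.52871835, 0.47128162], matching the test's expected output
```

The all-zeros case in the test follows the same way: every dot product is 0, so the softmax distributes the weight uniformly as [0.5, 0.5].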