test_attention.py source code


Project: OpenNMT-py  Author: OpenNMT
import torch
from torch.autograd import Variable  # pre-0.4 PyTorch API, as used by this snippet

import onmt
import onmt.modules


def test_masked_global_attention(self):
    # Lengths of the four source sequences in the batch.
    source_lengths = torch.IntTensor([7, 3, 5, 2])
    # A 1 marks a padding position (beyond the sequence's length) whose
    # attention weight must be exactly zero.
    illegal_weights_mask = torch.ByteTensor([
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1],
        [0, 0, 1, 1, 1, 1, 1]])

    batch_size = source_lengths.size(0)
    dim = 20

    context = Variable(torch.randn(batch_size, source_lengths.max(), dim))
    hidden = Variable(torch.randn(batch_size, dim))

    attn = onmt.modules.GlobalAttention(dim)

    _, alignments = attn(hidden, context, context_lengths=source_lengths)
    # Select the attention weights at the padding positions; masking in
    # GlobalAttention should have forced them all to zero.
    illegal_weights = alignments.masked_select(illegal_weights_mask)

    self.assertEqual(0.0, illegal_weights.data.sum())
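For reference, the `illegal_weights_mask` above is just the positions at or past each sequence's length. A minimal sketch of deriving such a mask from `source_lengths` in plain PyTorch (the helper name `length_mask` is hypothetical, not part of OpenNMT-py):

```python
import torch

def length_mask(lengths, max_len=None):
    # Returns a (batch, max_len) boolean mask where True marks padding
    # positions, i.e. position index >= sequence length.
    lengths = torch.as_tensor(lengths)
    if max_len is None:
        max_len = int(lengths.max())
    positions = torch.arange(max_len).unsqueeze(0)   # shape (1, max_len)
    return positions >= lengths.unsqueeze(1)         # broadcast to (batch, max_len)

mask = length_mask([7, 3, 5, 2])
print(mask.int())
# Row i has ones exactly where illegal_weights_mask has ones in the test above.
```

Broadcasting a position row vector against a length column vector is the usual way to build such masks without Python loops.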