test_nn.py 文件源码

python
阅读 43 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def test_pack_padded_sequence(self):
        """Check packing, unpacking and gradient flow for padded sequences.

        Builds a zero-padded batch of seven sequences, then for both the
        time-major and the batch-first layout verifies that:

        * ``pack_padded_sequence`` yields the expected data and batch sizes,
        * ``pad_packed_sequence`` reproduces the padded input and lengths,
        * the gradient reaching the input equals the upstream gradient on
          real time steps and is zero on padded ones.
        """
        def pad(tensor, length):
            # Append zero rows along dim 0 until the tensor has `length` rows.
            filler = tensor.new(length - tensor.size(0), *tensor.size()[1:]).zero_()
            return torch.cat([tensor, filler])

        # Listed in descending order, so the first entry is the longest one.
        lengths = [10, 8, 4, 2, 2, 2, 1]
        max_length = lengths[0]
        # batch_sizes[t] = number of sequences that still have a step at time t.
        batch_sizes = [sum(1 for length in lengths if length >= step)
                       for step in range(1, max_length + 1)]

        # Sequence i (1-based) holds the values i * 100 + 1 .. i * 100 + 5 * l,
        # laid out as (l, 1, 5); after padding and concatenating along dim 1
        # the batch has shape (max_length, len(lengths), 5).
        padded = torch.cat([pad(i * 100 + torch.arange(1, 5 * l + 1).view(l, 1, 5), max_length)
                            for i, l in enumerate(lengths, 1)], 1)
        padded = Variable(padded, requires_grad=True)

        # Packed layout: time step by time step, and within a step one row of
        # five values for each sequence that is still active.
        expected_data = [[torch.arange(1, 6) + (i + 1) * 100 + 5 * n for i in range(batch_size)]
                         for n, batch_size in enumerate(batch_sizes)]
        expected_data = list(itertools.chain.from_iterable(expected_data))
        expected_data = torch.stack(expected_data, dim=0)

        for batch_first in (True, False):
            src = padded
            if batch_first:
                src = src.transpose(0, 1)

            # check output
            packed = rnn_utils.pack_padded_sequence(src, lengths, batch_first=batch_first)
            self.assertEqual(packed.data.data, expected_data)
            self.assertEqual(packed.batch_sizes, batch_sizes)

            # test inverse
            unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed, batch_first=batch_first)
            self.assertEqual(unpacked, src)
            self.assertEqual(unpacked_len, lengths)

            # check grad
            # Clear what the previous loop iteration accumulated so that each
            # layout is verified on its own.
            if padded.grad is not None:
                padded.grad.data.zero_()
            grad_output = unpacked.data.clone().normal_()
            unpacked.backward(grad_output)
            if batch_first:
                # Bring the upstream gradient back to the time-major layout
                # of `padded` before comparing.
                grad_output.transpose_(0, 1)
            for i, l in enumerate(lengths):
                self.assertEqual(padded.grad.data[:l, i], grad_output[:l, i])
                # Only sequences shorter than the longest one have padded
                # steps; the slice below would be empty otherwise.
                if l < max_length:
                    self.assertEqual(padded.grad.data[l:, i].abs().sum(), 0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号