test_nn.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def _test_maxpool_indices(self, num_dim):
        """Check ``nn.MaxPool{num_dim}d(2, return_indices=True)`` on a ramp input.

        The input holds the values ``1 .. 4**num_dim`` in order, viewed as
        ``(1, 1, 4, ..., 4)``. Because the values increase with flat index, the
        maximum of every 2-wide pooling window sits at the window's last
        coordinate along each spatial dimension.

        The test verifies:

        * forward: the returned indices and pooled values (not checked for
          ``num_dim == 3``);
        * ``output.requires_grad`` is True and ``indices.requires_grad`` is False;
        * backward with an all-ones ``grad_output`` puts a 1 at every window
          maximum and 0 elsewhere in the input gradient;
        * modifying ``indices`` in place makes a second backward raise
          ``RuntimeError``.

        Args:
            num_dim: number of spatial dimensions; selects the module class
                ``nn.MaxPool{num_dim}d``.
        """
        side = 4  # extent of every spatial dimension of the input

        def expected_indices(dim):
            # Flat indices of each 2-wide window's maximum in a side**dim ramp.
            # Along every dimension the maximum sits at coordinate 1 or 3.
            if dim == 1:
                return torch.DoubleTensor([1, 3])
            lower = expected_indices(dim - 1)
            lower = lower.view(1, *lower.size())
            # One step along the new outermost dimension skips a whole
            # (dim-1)-dimensional slab of side**(dim-1) elements. For dim == 2
            # this gives the offsets 4 and 12.
            slab = side ** (dim - 1)
            return torch.cat((lower + slab, lower + 3 * slab), 0)

        def expected_grad(dim):
            # Input gradient for an all-ones grad_output: 1 at each window
            # maximum (coordinate 1 or 3 along every dimension), 0 elsewhere.
            if dim == 1:
                return torch.DoubleTensor([0, 1, 0, 1])
            lower = expected_grad(dim - 1)
            grad = lower.view(1, *lower.size())
            # clone().zero_() keeps the zero slab the same tensor type as grad,
            # instead of depending on the global default tensor type for cat().
            zero = grad.clone().zero_()
            return torch.cat((zero, grad, zero, grad), 0)

        module_cls = getattr(nn, 'MaxPool{}d'.format(num_dim))
        module = module_cls(2, return_indices=True)
        numel = side ** num_dim
        inp = torch.range(1, numel).view(1, 1, *repeat(side, num_dim))
        input_var = Variable(inp, requires_grad=True)

        # Check forward
        output, indices = module(input_var)
        # NOTE(review): forward indices/values are not compared for 3D;
        # presumably the 3D module reports indices in a different encoding
        # than flat offsets -- TODO confirm before enabling.
        if num_dim != 3:
            ref_indices = expected_indices(num_dim)
            # Each input value equals its flat index + 1, so each pooled
            # maximum is its index + 1.
            ref_output = ref_indices + 1
            self.assertEqual(indices.data.squeeze(), ref_indices)
            self.assertEqual(output.data.squeeze(), ref_output)
        self.assertTrue(output.requires_grad)
        self.assertFalse(indices.requires_grad)

        # Make sure backward works
        grad_output = torch.DoubleTensor(output.size()).fill_(1)
        # Keep the graph alive so backward can be attempted again below.
        output.backward(grad_output, retain_variables=True)
        ref_grad = expected_grad(num_dim)
        self.assertEqual(input_var.grad, ref_grad.view_as(inp))

        # Make sure backward after changing indices will result in an error:
        # backward depends on the indices, so an in-place edit must be caught.
        indices.add_(1)
        self.assertRaises(RuntimeError, lambda: output.backward(grad_output))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号