pytorch.py 文件源码

python
阅读 44 收藏 0 点赞 0 评论 0

项目:foolbox 作者: bethgelab 项目源码 文件源码
def backward(self, gradient, image):
        """Backpropagate an output-space vector to the model input.

        Computes the gradient, with respect to the (processed) input image,
        of ``dot(predictions, gradient)``, where ``predictions`` is the
        model's output vector for ``image``. This is equivalent to
        ``predictions.backward(gradient=gradient)``, i.e. a
        vector-Jacobian product.

        Parameters
        ----------
        gradient : numpy.ndarray
            1-D array with one entry per model output; its length must equal
            the size of the model's prediction vector for a single input.
            NOTE(review): ``torch.dot`` does not promote dtypes, so this
            presumably has to match the dtype of the model outputs
            (e.g. float32) -- confirm with callers.
        image : numpy.ndarray
            A single, unbatched input. After ``self._process_input`` it must
            be 3-D; a batch axis of size 1 is added before the forward pass.

        Returns
        -------
        numpy.ndarray
            The input gradient after ``self._process_gradient`` (which is
            applied to the still-batched array), with the batch axis
            removed. It has the same shape as the processed image.

        Raises
        ------
        AssertionError
            If ``gradient`` is not 1-D, the processed image is not 3-D, the
            model does not return one 1-D prediction vector per input, or
            the prediction and gradient sizes disagree.
        """
        # lazy import: torch is only imported when this method is called
        import torch
        from torch.autograd import Variable

        assert gradient.ndim == 1

        gradient = torch.from_numpy(gradient)
        if self.cuda:  # pragma: no cover
            gradient = gradient.cuda()
        gradient = Variable(gradient)

        image = self._process_input(image)
        assert image.ndim == 3
        # The model is called on a batch, so wrap the single image in a
        # batch of one.
        images = image[np.newaxis]
        images = torch.from_numpy(images)
        if self.cuda:  # pragma: no cover
            images = images.cuda()
        # Variable is presumably kept for compatibility with older PyTorch
        # releases (pre-0.4) -- confirm the supported torch range before
        # replacing it with requires_grad_().
        images = Variable(images, requires_grad=True)
        predictions = self._model(images)

        # Keep only the prediction vector of our single batch element.
        predictions = predictions[0]

        assert gradient.dim() == 1
        assert predictions.dim() == 1
        assert gradient.size() == predictions.size()

        # d/d(images) of <predictions, gradient>; should be the same as
        # predictions.backward(gradient=gradient).
        loss = torch.dot(predictions, gradient)
        loss.backward()

        grad = images.grad.data
        if self.cuda:  # pragma: no cover
            grad = grad.cpu()
        grad = grad.numpy()
        # _process_gradient receives the batched (1, ...) array; the batch
        # axis is squeezed away only afterwards.
        grad = self._process_gradient(grad)
        grad = np.squeeze(grad, axis=0)
        assert grad.shape == image.shape

        return grad
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号