gradients.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:pytorch-smoothgrad 作者: pkdn 项目源码 文件源码
def __call__(self, x, index=None):
        output = self.pretrained_model(x)

        if index is None:
            index = np.argmax(output.data.cpu().numpy())

        one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
        one_hot[0][index] = 1
        if self.cuda:
            one_hot = Variable(torch.from_numpy(one_hot).cuda(), requires_grad=True)
        else:
            one_hot = Variable(torch.from_numpy(one_hot), requires_grad=True)
        one_hot = torch.sum(one_hot * output)

        one_hot.backward(retain_variables=True)

        grad = x.grad.data.cpu().numpy()
        grad = grad[0, :, :, :]

        return grad
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号