saliency.py 文件源码

python
阅读 48 收藏 0 点赞 0 评论 0

项目:DeepLearning_PlantDiseases 作者: MarkoArsenovic 项目源码 文件源码
def Saliency_map(image,model,preprocess,ground_truth,use_gpu=False,method=util.GradType.GUIDED):
    vis_param_dict['method'] = method
    img_tensor = preprocess(image)
    img_tensor.unsqueeze_(0)
    if use_gpu:
        img_tensor=img_tensor.cuda()
    input = Variable(img_tensor,requires_grad=True)

    if  input.grad is not None:
        input.grad.data.zero_()

    model.zero_grad()
    output = model(input)
    ind=torch.LongTensor(1)
    if(isinstance(ground_truth,np.int64)):
        ground_truth=np.asscalar(ground_truth)
    ind[0]=ground_truth
    ind=Variable(ind)
    energy=output[0,ground_truth]
    energy.backward() 
    grad=input.grad
    if use_gpu:
        return np.abs(grad.data.cpu().numpy()[0]).max(axis=0)
    return np.abs(grad.data.numpy()[0]).max(axis=0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号