demo.py 文件源码

python
阅读 50 收藏 0 点赞 0 评论 0

项目:pytorch_crowd_count 作者: BingzheWu 项目源码 文件源码
def demo(img_path):
    net = predict_net()
    net.load_state_dict(torch.load('checkpoint/crowd_net2.pth'))
    input_img = read_gray_img(img_path)
    input_img = torch.autograd.Variable(torch.Tensor(input_img/255.0))
    print(input_img.size())
    #input_image = input_image.view(1, 3, 255, 255)
    heat_map = net.forward(input_img)
    print heat_map.size()
    heat_map = torch.squeeze(heat_map)
    heat_map = heat_map.data.numpy()
    plt.imshow(heat_map, cmap = 'hot')
    plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号