test.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:superres 作者: ntomita 项目源码 文件源码
def test(argv=sys.argv[1:]):
    input = "../dataset/BSDS300/images/val/54082.jpg"
    #input = "../dataset/BSDS300/images/val/159008.jpg"
    output = "sr_{}".format(basename(input))  # save in cwd
    output2 = "sr__{}".format(basename(input))
    model = "snapshot/gnet-epoch-1-pretrain.pth"
    #model = "snapshot/gnet-epoch-200.pth"
    cuda = True
    img = Image.open(input)
    width, height = img.size

    gennet = torch.load(model)
    img = ToTensor()(img)  # [c,w,h]->[1,c,h,w]
    input = Variable(img).view(1, 3, height, width)

    if cuda:
        gennet = gennet.cuda()
        input = input.cuda()

    pred = gennet(input).cpu()
    save_image(pred.data, output)
    #ToPILImage()(pred.data).save(output)

    toImage(pred).save(output2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号