def test(argv=None,
         input_path="../dataset/BSDS300/images/val/54082.jpg",
         model_path="snapshot/gnet-epoch-1-pretrain.pth",
         cuda=None):
    """Super-resolve one image with a saved generator and write the result twice.

    Two files are written to the current working directory, named after the
    basename of ``input_path``:
    ``sr_<name>`` (via ``save_image``) and ``sr__<name>`` (via ``toImage``).

    Args:
        argv: Unused. Kept so existing callers that pass CLI arguments still
            work. (It previously defaulted to ``sys.argv[1:]``, a list
            captured once at definition time.)
        input_path: Path of the input image. Defaults to the original
            hard-coded BSDS300 validation image.
        model_path: Path of a checkpoint that ``torch.load`` turns directly
            into a callable module (the whole pickled generator, not a
            state_dict). Defaults to the original hard-coded snapshot.
        cuda: Run on the GPU if True, on the CPU if False.
            ``None`` (default) uses CUDA when it is available.
    """
    if cuda is None:
        cuda = torch.cuda.is_available()

    output = "sr_{}".format(basename(input_path))  # save in cwd
    output2 = "sr__{}".format(basename(input_path))

    # SECURITY: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    # Mapping storages to CPU lets a GPU-saved checkpoint load on CPU-only hosts;
    # the model is moved to the GPU below when requested.
    gennet = torch.load(model_path, map_location=lambda storage, loc: storage)
    # Inference: switch to eval mode (affects layers such as BatchNorm/Dropout,
    # if the generator contains any).
    gennet.eval()

    # Force 3 channels; grayscale/RGBA inputs would otherwise not match the
    # 3-channel layout the original code assumed.
    img = Image.open(input_path).convert("RGB")
    # ToTensor gives [C, H, W]; add a leading batch dimension -> [1, C, H, W].
    lr = Variable(ToTensor()(img).unsqueeze(0))

    if cuda:
        gennet = gennet.cuda()
        lr = lr.cuda()

    pred = gennet(lr).cpu()
    save_image(pred.data, output)
    toImage(pred).save(output2)
评论列表
文章目录