def test_lapsrn():
    """Run the saved LapSRN checkpoint on one image and write the results.

    Reads the Set14 "zebra" ground-truth image, passes it through
    ``img_preprocess`` to get the low-resolution input and the image
    pyramid, loads epoch 100 of the ``lapsrn`` checkpoint, runs a single
    forward pass on the GPU, and writes the low-resolution input, the last
    pyramid image and the first two network outputs as BMP files under
    ``results``.

    Raises:
        IOError: If the test image cannot be read.
    """
    import os  # local import: only needed to create the output directory

    # Number of label entries bound and network outputs saved below.
    num_levels = 2

    img_path = ("E:\\DevProj\\Datasets\\SuperResolution\\SR_testing_datasets"
                "\\Set14\\GT\\zebra.png")
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise IOError("Could not read test image: {}".format(img_path))

    # NOTE(review): the second argument is presumably the number of pyramid
    # levels (or the scale) -- confirm against img_preprocess.
    img_lr, img_pryd = img_preprocess(img, 2)
    one_batch = LapSRNDataBatch(img_lr, img_pryd)

    provide_data = [('imglr', img_lr.shape)]
    provide_label = [
        ("loss_s{}_imggt".format(s), img_pryd[s].shape)
        for s in range(num_levels)
    ]

    net, arg_params, aux_params = mx.model.load_checkpoint(
        "checkpoint\\lapsrn", 100)
    # Module defaults to data_names=('data',) / label_names=('softmax_label',);
    # pass the real input names so they match the shapes given to bind().
    mod = mx.mod.Module(
        symbol=net,
        data_names=[name for name, _ in provide_data],
        label_names=[name for name, _ in provide_label],
        context=mx.gpu(),
    )
    mod.bind(for_training=False,
             data_shapes=provide_data,
             label_shapes=provide_label)
    mod.set_params(arg_params, aux_params, allow_missing=True)
    mod.forward(one_batch)
    img_sr = mod.get_outputs()

    # cv2.imwrite only returns False when the target directory is missing,
    # so make sure it exists before writing anything.
    if not os.path.isdir("results"):
        os.makedirs("results")

    cv2.imwrite("results\\lapsrn_imglr.bmp", img_recover(img_lr))
    cv2.imwrite("results\\lapsrn_imghr.bmp", img_recover(img_pryd[-1]))
    for s in range(num_levels):
        # Outputs are MXNet NDArrays; convert to numpy before recovering.
        img_temp = img_recover(img_sr[s].asnumpy())
        cv2.imwrite("results\\lapsrn_imgsr{}.bmp".format(s), img_temp)
# [web page residue, not code] Comment list
# [web page residue, not code] Table of contents