demo_lapsrn.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:SuperResolutionCNN 作者: galad-loth 项目源码 文件源码
def test_lapsrn(img_path=("E:\\DevProj\\Datasets\\SuperResolution\\SR_testing_datasets"
                          "\\Set14\\GT\\zebra.png"),
                checkpoint_prefix=None, epoch=100, out_dir="results", ctx=None):
    """Run a trained LapSRN checkpoint on one test image and save the results.

    Reads ``img_path`` with OpenCV, builds the low-resolution input and the
    ground-truth pyramid via ``img_preprocess``, binds the checkpoint for
    inference, runs a single forward pass and writes the LR input, the HR
    ground truth and each super-resolved output as BMP files into ``out_dir``.

    Args:
        img_path: path of the colour test image.
        checkpoint_prefix: MXNet checkpoint prefix; defaults to
            ``checkpoint<sep>lapsrn`` using the OS path separator.
        epoch: checkpoint epoch to load.
        out_dir: directory receiving the output images; created if missing.
        ctx: MXNet context to run on; defaults to ``mx.gpu()``.

    Raises:
        IOError: if the input image cannot be read or an output image
            cannot be written.
    """
    import os

    # Number of pyramid levels the module is bound with (one label shape and
    # one saved SR output per level).
    # NOTE(review): assumed to match len(img_pryd) as produced by
    # img_preprocess(img, 2) -- confirm against img_preprocess.
    num_levels = 2

    if checkpoint_prefix is None:
        checkpoint_prefix = os.path.join("checkpoint", "lapsrn")
    if ctx is None:
        ctx = mx.gpu()

    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        raise IOError("could not read image: {}".format(img_path))

    img_lr, img_pryd = img_preprocess(img, 2)
    one_batch = LapSRNDataBatch(img_lr, img_pryd)

    net, arg_params, aux_params = mx.model.load_checkpoint(checkpoint_prefix, epoch)
    mod = mx.mod.Module(symbol=net, context=ctx)

    provide_data = [('imglr', img_lr.shape)]
    provide_label = [("loss_s{}_imggt".format(s), img_pryd[s].shape)
                     for s in range(num_levels)]
    mod.bind(for_training=False,
             data_shapes=provide_data,
             label_shapes=provide_label)
    # allow_missing=True: parameters absent from the checkpoint are
    # initialized instead of causing an error.
    mod.set_params(arg_params, aux_params, allow_missing=True)

    mod.forward(one_batch)
    img_sr = mod.get_outputs()

    # cv2.imwrite does not create directories; make sure the target exists.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    def _save(name, image):
        """Write ``image`` to ``out_dir/name``, raising if OpenCV fails."""
        path = os.path.join(out_dir, name)
        # cv2.imwrite reports failure via its boolean return, not an exception.
        if not cv2.imwrite(path, image):
            raise IOError("could not write image: {}".format(path))

    _save("lapsrn_imglr.bmp", img_recover(img_lr))
    _save("lapsrn_imghr.bmp", img_recover(img_pryd[-1]))
    for s in range(num_levels):
        _save("lapsrn_imgsr{}.bmp".format(s), img_recover(img_sr[s].asnumpy()))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号