infer_file.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:chainer-deconv 作者: germanRos 项目源码 文件源码
def inference():
    """Run the segmentation model on every image listed in TESTFILE and show
    each prediction blended over its input frame.

    Relies on names defined elsewhere in the project/file:
    ``Netmodel``, ``CLASSES``, ``MODEL_NAME`` (weights loaded with
    ``serializers.load_npz``), ``GPU_ID``, ``NEWSIZE`` (size passed to
    ``misc.imresize``), ``TESTFILE`` (text file with one image path per line),
    ``fromHEX2RGB`` and ``stats_opts['colormap']``.

    Each result is shown in the OpenCV window 'frame'. The loop blocks until
    a key is pressed, and pressing 'q' stops early. OpenCV windows are
    destroyed on exit, including when an exception propagates.
    """
    cv2.namedWindow('frame', cv2.WINDOW_NORMAL)
    try:
        # Load the trained weights and move the model to the selected GPU.
        model = Netmodel('eval-model', CLASSES)
        serializers.load_npz(MODEL_NAME, model)
        cuda.get_device(GPU_ID).use()
        model.to_gpu()

        # Lookup table turning predicted label indices into display colours.
        LUT = fromHEX2RGB(stats_opts['colormap'])

        # Single-image input buffer (1, 3, NEWSIZE[1], NEWSIZE[0]), reused for
        # every frame so it is allocated only once.
        batchRGB = np.zeros((1, 3, NEWSIZE[1], NEWSIZE[0]), dtype='float32')

        # Go through the data, one image path per line.
        with open(TESTFILE) as f:
            for line in f:
                path = line.rstrip('\n')
                if not path:
                    # Skip blank lines (e.g. a trailing empty line) instead of
                    # handing an empty path to imread.
                    continue

                frame = misc.imread(path)
                # Swap the R and B channels. BGR2RGB and RGB2BGR are the same
                # swap, so this yields the opposite channel order to what
                # imread returned (used as-is by cv2.imshow below).
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                # Process the frame.
                im = misc.imresize(frame, NEWSIZE, interp='bilinear')
                # Conversion from HxWxC to CxWxH, matching batchRGB's layout.
                batchRGB[0, :, :, :] = im.astype(np.float32).transpose((2, 1, 0))
                # Subtract a constant 127 from every value before inference.
                batchRGBn = batchRGB - 127.0

                # Data ready: copy to the GPU.
                batch = chainer.Variable(cuda.cupy.asarray(batchRGBn))

                # Forward pass. The output is read back from model.probs, and
                # the per-pixel label is its argmax over axis 1.
                model((batch, []), test_mode=2)
                pred = cuda.to_cpu(model.probs.data.argmax(1))

                # Colourise the labels. The LUT is indexed with label + 1
                # (NOTE(review): presumably row 0 is reserved — confirm).
                pred_ = LUT[pred + 1, :].squeeze()
                # Undo the W/H swap so the prediction lines up with `im`.
                pred_ = pred_.transpose((1, 0, 2))
                pred2 = cv2.cvtColor(pred_, cv2.COLOR_BGR2RGB)

                # 40/60 blend of the input frame and the colourised prediction.
                disp = (0.4 * im + 0.6 * pred2).astype(np.uint8)

                # Display the resulting frame; wait for a key, 'q' quits.
                cv2.imshow('frame', disp)
                if cv2.waitKey(-1) & 0xFF == ord('q'):
                    break
    finally:
        # Always release the OpenCV windows, even if loading or inference fails.
        cv2.destroyAllWindows()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号