def inference():
    """Run the segmentation model over every image listed in ``TESTFILE``
    and display each prediction blended over its input image.

    ``TESTFILE`` is read line by line; each non-empty line is treated as an
    image path. Each image is:

    1. read and channel-swapped;
    2. resized to ``NEWSIZE``;
    3. shifted by -127;
    4. pushed through the network on GPU ``GPU_ID``.

    The per-pixel argmax of ``model.probs`` is colour-coded with the colormap
    from ``stats_opts['colormap']``. The window ``'frame'`` then shows
    ``0.4 * image + 0.6 * colour-coded prediction``.

    Each frame stays on screen until a key is pressed; pressing ``q`` stops
    early. OpenCV windows are destroyed on exit, including when an error is
    raised mid-loop.
    """
    cv2.namedWindow('frame', cv2.WINDOW_NORMAL)
    try:
        # Load the trained weights and move the model to the selected GPU.
        model = Netmodel('eval-model', CLASSES)
        serializers.load_npz(MODEL_NAME, model)
        cuda.get_device(GPU_ID).use()
        model.to_gpu()

        # Lookup table mapping label index -> colour row.
        lut = fromHEX2RGB(stats_opts['colormap'])

        # Single-image batch buffer, reused for every frame.
        # Layout: (batch=1, channel=3, NEWSIZE[1], NEWSIZE[0]).
        batch_rgb = np.zeros((1, 3, NEWSIZE[1], NEWSIZE[0]), dtype='float32')

        with open(TESTFILE) as f:
            for line in f:
                # One image path per line. Stripping '\r' as well as '\n'
                # keeps files with Windows line endings working.
                path = line.rstrip('\r\n')
                if not path:
                    # Skip blank lines (e.g. a trailing empty line) instead of
                    # handing an empty path to imread.
                    continue

                frame = misc.imread(path)
                # Swap the first and third channels. NOTE(review): presumably
                # misc.imread yields RGB, so this produces BGR, which both the
                # network and cv2.imshow below receive. Confirm the model was
                # trained on BGR input.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                # NOTE(review): scipy.misc.imread/imresize are removed in
                # recent SciPy releases; this path needs an old SciPy.
                im = misc.imresize(frame, NEWSIZE, interp='bilinear')

                # HxWxC -> CxWxH, matching batch_rgb's (C, NEWSIZE[1],
                # NEWSIZE[0]) slot.
                batch_rgb[0, :, :, :] = im.astype(np.float32).transpose((2, 1, 0))

                # Centre intensities around zero, then upload to the GPU.
                batch = chainer.Variable(cuda.cupy.asarray(batch_rgb - 127.0))

                # Forward pass; the model stores its class scores in
                # model.probs.
                model((batch, []), test_mode=2)
                # Per-pixel label, shape (1, W, H); copy back to the CPU.
                pred = cuda.to_cpu(model.probs.data.argmax(1))

                # Colour-code the labels. The +1 offset presumably skips a
                # reserved first LUT row (e.g. a void class) -- TODO confirm.
                # Then undo the W/H swap made above: (W, H, 3) -> (H, W, 3).
                pred_rgb = lut[pred + 1, :].squeeze().transpose((1, 0, 2))
                # Same channel swap as the input, so the overlay colours
                # line up with the BGR image shown by cv2.imshow.
                pred_bgr = cv2.cvtColor(pred_rgb, cv2.COLOR_BGR2RGB)

                # Blend prediction over the resized input and display it.
                disp = (0.4 * im + 0.6 * pred_bgr).astype(np.uint8)
                cv2.imshow('frame', disp)

                # waitKey(0) blocks until a key press (documented "forever").
                if cv2.waitKey(0) & 0xFF == ord('q'):
                    break
    finally:
        # Always release the OpenCV windows, even if inference failed.
        cv2.destroyAllWindows()
评论列表
文章目录