def main():
    """Classify a single image with a saved MXNet checkpoint and print the results.

    Reads its configuration from the module-level ``args`` object
    (``args.synset``, ``args.img``, ``args.gpu``, ``args.prefix``,
    ``args.epoch``), which is defined outside this function.

    Prints the top-1 label and up to five top-ranked labels.

    Raises:
        IOError: if the image at ``args.img`` cannot be read by OpenCV.
    """
    # Class-index -> label table, one label per line. Use a context manager
    # so the file handle is closed deterministically.
    with open(args.synset) as synset_file:
        synset = [line.strip() for line in synset_file]

    # cv2.imread signals failure by returning None instead of raising;
    # check here so the user gets a clear message rather than an opaque
    # error from cvtColor.
    bgr = cv2.imread(args.img)
    if bgr is None:
        raise IOError("Failed to read image: %s" % args.img)

    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (224, 224))  # resize to 224*224 to fit model
    img = np.transpose(img, (2, 0, 1))  # (h, w, c) -> (c, h, w)
    img = img[np.newaxis, :]  # extend to (n, c, h, w)

    ctx = mx.gpu(args.gpu)
    sym, arg_params, aux_params = mx.model.load_checkpoint(args.prefix, args.epoch)
    # ch_dev is defined elsewhere; presumably it copies the parameters onto
    # ctx -- TODO confirm against its definition.
    arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)

    # The bound argument dict also carries the network inputs. The label
    # slot only gets an uninitialized placeholder.
    arg_params["data"] = mx.nd.array(img, ctx)
    arg_params["softmax_label"] = mx.nd.empty((1,), ctx)

    # No gradients are requested: this is a forward-only pass.
    exe = sym.bind(ctx, arg_params, args_grad=None, grad_req="null", aux_states=aux_params)
    exe.forward(is_train=False)

    prob = np.squeeze(exe.outputs[0].asnumpy())
    pred = np.argsort(prob)[::-1]  # class indices, highest score first

    print("Top1 result is: ", synset[pred[0]])
    # Slice instead of range(5) so models with fewer than five classes
    # do not raise IndexError.
    print("Top5 result is: ", [synset[idx] for idx in pred[:5]])
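# 评论列表 (comment list) -- page-navigation residue from the source article, not code
# 文章目录 (article table of contents) -- page-navigation residue from the source article, not code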