predict.py source code

python

Project: ResNet  Author: tornadomeet
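import cv2
import mxnet as mx
import numpy as np

# `args` (the parsed command-line options) and `ch_dev` (which, judging by its use
# here, copies the parameters onto `ctx`) are defined elsewhere in predict.py.

def main():
    # Load the class labels, one per line
    with open(args.synset) as f:
        synset = [l.strip() for l in f]
    bgr = cv2.imread(args.img)
    if bgr is None:  # cv2.imread returns None instead of raising on a bad path
        raise FileNotFoundError(args.img)
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB
    img = cv2.resize(img, (224, 224))  # resize to 224x224 to fit the model input
    img = np.swapaxes(img, 0, 2)
    img = np.swapaxes(img, 1, 2)  # change to (c, h, w) order
    img = img[np.newaxis, :]  # extend to (n, c, h, w)

    ctx = mx.gpu(args.gpu)
    # Load the network symbol and trained weights saved under prefix/epoch
    sym, arg_params, aux_params = mx.model.load_checkpoint(args.prefix, args.epoch)
    arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
    arg_params["data"] = mx.nd.array(img, ctx)
    # The label is not used by an inference forward pass, but it is an input of
    # the symbol, so bind needs a value for it
    arg_params["softmax_label"] = mx.nd.empty((1,), ctx)
    exe = sym.bind(ctx, arg_params, args_grad=None, grad_req="null", aux_states=aux_params)
    exe.forward(is_train=False)

    # Class probabilities for the single image, indices sorted most to least likely
    prob = np.squeeze(exe.outputs[0].asnumpy())
    pred = np.argsort(prob)[::-1]
    print("Top1 result is: ", synset[pred[0]])
    print("Top5 result is: ", [synset[pred[i]] for i in range(5)])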
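The preprocessing in main() reorders OpenCV's height x width x channel array into the (n, c, h, w) batch layout that the network takes as input. The NumPy-only sketch below reproduces those steps on a tiny made-up array, so it needs no OpenCV or image file. It also checks that the two swapaxes calls equal a single np.transpose(img, (2, 0, 1)). The helper name to_nchw is hypothetical. Reversing the last axis stands in for cv2.cvtColor(..., cv2.COLOR_BGR2RGB), which for 3-channel input swaps the first and third channels.

```python
import numpy as np


def to_nchw(img_hwc):
    """Hypothetical helper: turn one H x W x C image into a 1 x C x H x W batch."""
    # Same two steps as main(): (h, w, c) -> (c, w, h) -> (c, h, w)
    chw = np.swapaxes(np.swapaxes(img_hwc, 0, 2), 1, 2)
    # Add a leading batch axis: (c, h, w) -> (1, c, h, w)
    return chw[np.newaxis, :]


# Made-up "BGR" image: height 2, width 4, 3 channels
bgr = np.arange(2 * 4 * 3, dtype=np.uint8).reshape(2, 4, 3)
# Reversing the channel axis swaps B and R for a 3-channel image
rgb = bgr[..., ::-1]

batch = to_nchw(rgb)
print(batch.shape)  # (1, 3, 2, 4)
# The same layout change written as one transpose
print(np.array_equal(batch[0], np.transpose(rgb, (2, 0, 1))))  # True
```

Writing the conversion as one transpose states the target axis order directly. Both versions produce the same array.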
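The last four lines of main() turn the network output into labels. np.argsort sorts in ascending order, so [::-1] puts the most probable class first, and each index is then looked up in synset. Below is a small self-contained sketch of that decoding step. It uses a made-up six-class probability vector and label list, and top_k is a hypothetical helper name.

```python
import numpy as np


def top_k(prob, synset, k=5):
    """Return the k labels with the highest probability, best first."""
    # argsort is ascending; [::-1] flips it so position 0 is the most likely class
    order = np.argsort(prob)[::-1]
    # Slicing (instead of range(k)) also copes with fewer than k classes
    return [synset[i] for i in order[:k]]


# Hypothetical labels (a real synset file has one label per line) and output
synset = ["cat", "dog", "car", "tree", "boat", "bird"]
prob = np.array([0.05, 0.40, 0.10, 0.02, 0.30, 0.13])

print(top_k(prob, synset, k=1))  # ['dog']
print(top_k(prob, synset, k=5))  # ['dog', 'boat', 'bird', 'car', 'cat']
```

A full sort is cheap for a 1000-class output. For much larger label sets, np.argpartition can pick out the top k first, after which only those k entries need sorting.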
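main() also calls a helper, ch_dev, whose definition is not part of this snippet. Judging only from how it is called, it returns the argument and auxiliary parameter dicts with every array placed on ctx. The sketch below shows one way such a helper could be written, assuming the values are MXNet NDArrays. NDArray.as_in_context returns a copy on the target device, or the array itself if it is already there. This is a hypothetical stand-in, not the project's actual code.

```python
def ch_dev(arg_params, aux_params, ctx):
    """Hypothetical stand-in for the ch_dev helper used in main().

    Assumes every value is an MXNet NDArray (or any object with an
    as_in_context method) and returns new dicts whose values live on ctx.
    """
    # Weights and biases (the "arg" parameters)
    new_args = {name: arr.as_in_context(ctx) for name, arr in arg_params.items()}
    # Auxiliary state such as batch-norm moving mean and variance
    new_auxs = {name: arr.as_in_context(ctx) for name, arr in aux_params.items()}
    return new_args, new_auxs
```

With MXNet installed it would be called as ch_dev(arg_params, aux_params, mx.gpu(0)). It returns fresh dicts, so adding the "data" and "softmax_label" entries afterwards does not modify the dicts returned by load_checkpoint.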