TensorFlow:如何从SavedModel进行预测?

发布于 2021-01-29 15:15:45

我已经导出了SavedModel,现在我可以将其加载回并进行预测。经过培训,具有以下功能和标签:

F1 : FLOAT32
F2 : FLOAT32
F3 : FLOAT32
L1 : FLOAT32

所以说我要输入的值得20.9, 1.8, 0.9到一个FLOAT32预测。我该怎么做?我已经成功地加载了模型,但是我不确定如何访问它以进行预测调用。

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        "/job/export/Servo/1503723455"
    )

    # How can I predict from here?
    # I want to do something like prediction = model.predict([20.9, 1.8, 0.9])

该问题不是此处发布的问题的重复。这个问题集中于在SavedModel任何模型类(不仅仅限于tf.estimator)上进行推理的最小示例,以及指定输入和输出节点名称的语法。

关注者
0
被浏览
47
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    加载图形后,它就可以在当前上下文中使用,您可以通过它馈入输入数据以获得预测。每个用例都有很大的不同,但是在代码中添加的内容如下所示:

    with tf.Session(graph=tf.Graph()) as sess:
        tf.saved_model.loader.load(
            sess,
            [tf.saved_model.tag_constants.SERVING],
            "/job/export/Servo/1503723455"
        )
    
        prediction = sess.run(
            'prefix/predictions/Identity:0',
            feed_dict={
                'Placeholder:0': [20.9],
                'Placeholder_1:0': [1.8],
                'Placeholder_2:0': [0.9]
            }
        )
    
        print(prediction)
    

    在这里,您需要知道预测输入的名称。如果您没有给他们带来天真serving_fn,则它们默认为Placeholder_n,这n是第n个功能。

    的第一个字符串参数sess.run是预测目标的名称。这将根据您的用例而有所不同。



知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看