model.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:tensorflow_mnist_cloudml 作者: mainyaa 项目源码 文件源码
def build_prediction_graph(self, export_dir):
    """Construct the serving graph and register its input/output aliases.

    The graph takes a batch of serialized ``tf.Example`` protos, parses an
    ``image`` feature (``IMAGE_PIXELS`` float32 values) and a scalar string
    ``key`` from each one, and runs ``inference`` on the images. It then
    publishes the keys, the arg-max class predictions and the softmax scores
    as named outputs via the ``inputs``/``outputs`` graph collections.

    Args:
      export_dir: Export destination; within this function it is only used
        in the log message.
    """
    logging.info('Exporting prediction graph to %s', export_dir)

    # Input: a 1-D batch of serialized Example protos.
    serialized = tf.placeholder(tf.string, shape=(None,))
    feature_spec = {
        'image': tf.FixedLenFeature(shape=[IMAGE_PIXELS], dtype=tf.float32),
        'key': tf.FixedLenFeature(shape=[], dtype=tf.string),
    }
    parsed_features = tf.parse_example(serialized, feature_spec)
    image_batch = parsed_features['image']
    key_batch = parsed_features['key']

    # Model head: per-class probabilities and the most likely class index.
    scores = tf.nn.softmax(inference(image_batch, self.hidden1, self.hidden2))
    predicted_class = tf.argmax(scores, 1)

    # Serving signature. An alias ending in '_bytes' marks a tensor whose value
    # is raw bytes, so it travels base64-encoded over HTTP. Any output alias
    # with that suffix is likewise base64-encoded in the HTTP response, and the
    # client must base64-decode it to get the binary value.
    input_aliases = {'examples_bytes': serialized.name}
    output_aliases = {
        'key': key_batch.name,
        'prediction': predicted_class.name,
        'scores': scores.name,
    }
    tf.add_to_collection('inputs', json.dumps(input_aliases))
    tf.add_to_collection('outputs', json.dumps(output_aliases))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号