mock_client.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:tfserving_predict_client 作者: epigramai 项目源码 文件源码
def predict(self, request_data, request_timeout=10):
    """Send ``request_data`` to the TF Serving model and return its scores.

    The data is converted to a tensor proto, placed under the ``'inputs'``
    key of a ``PredictRequest`` and sent to ``self.host`` over an insecure
    gRPC channel. The channel is opened per call and always closed before
    returning.

    Args:
        request_data: Input to predict on. Must expose a ``.shape``
            attribute (presumably a numpy array -- confirm with callers).
        request_timeout: Forwarded as the ``timeout`` argument of the
            ``Predict`` RPC.

    Returns:
        A list with the ``float_val`` entries of the ``'scores'`` output.
        If the RPC raises ``RpcError``, the error is logged and a list of
        ``self.num_scores`` zeros is returned instead.
    """
    logger.info('Sending request to tfserving model')
    logger.info('Model name: %s', self.model_name)
    logger.info('Model version: %s', self.model_version)
    logger.info('Host: %s', self.host)

    tensor_shape = request_data.shape

    # 'incv4' and 'res152' keep the dtype of request_data; every other
    # model gets its input converted to float32.
    if self.model_name in ('incv4', 'res152'):
        features_tensor_proto = tf.contrib.util.make_tensor_proto(request_data, shape=tensor_shape)
    else:
        features_tensor_proto = tf.contrib.util.make_tensor_proto(request_data,
                                                                  dtype=tf.float32, shape=tensor_shape)

    # Build the request before opening the channel so that a failure here
    # cannot leave a channel behind.
    request = PredictRequest()
    request.model_spec.name = self.model_name

    # The version is only set when it is positive; otherwise the field is
    # left unset on the request.
    if self.model_version > 0:
        request.model_spec.version.value = self.model_version

    request.inputs['inputs'].CopyFrom(features_tensor_proto)

    # Create gRPC client; the channel is closed in ``finally`` so repeated
    # calls do not accumulate open channels.
    channel = grpc.insecure_channel(self.host)
    try:
        stub = PredictionServiceStub(channel)
        try:
            result = stub.Predict(request, timeout=request_timeout)
            scores = list(result.outputs['scores'].float_val)
            logger.debug('Predicted scores with len: %s', len(scores))
            return scores
        except RpcError as e:
            # Deliberate best-effort: callers get a zero-filled prediction
            # instead of an exception when the RPC fails.
            logger.warning(e)
            logger.warning('Prediction failed. Mock client will return empty prediction of length: %s',
                           self.num_scores)
            return [0] * self.num_scores
    finally:
        channel.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号