local_predict_tests.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:pydatalab 作者: googledatalab 项目源码 文件源码
def _create_model(self, dir_name):
  """Create a simple model that takes 'key', 'num1', 'text1', 'img_url1' input.

  The serving signature takes one string tensor ``input`` of shape (None,).
  Each element is a CSV line with four columns:
    key      -> parsed as int64
    num1     -> parsed as float32
    text1    -> parsed as a string, then converted with tf.string_to_number
    img_url1 -> urlsafe-base64-encoded JPEG bytes; an empty value is replaced
                by a blank 16x16 RGB JPEG.
  Each image is reduced to a single scalar: its maximum pixel value over all
  pixels and channels.

  Signature outputs (one value per input row):
    'key'  -> the parsed key column
    'var1' -> element-wise minimum of (num1, text1, image max pixel)
    'var2' -> element-wise maximum of (num1, text1, image max pixel)

  Args:
    dir_name: name of the sub-directory of ``self._test_dir`` that the
      SavedModel is written to.

  Returns:
    The path of the written SavedModel directory.
  """
  # Build the fallback image once in plain Python; it does not depend on the
  # per-element input. tf.decode_base64 expects the web-safe base64 alphabet,
  # hence urlsafe_b64encode.
  img_buf = BytesIO()
  Image.new('RGB', (16, 16)).save(img_buf, 'jpeg')
  default_image_string = base64.urlsafe_b64encode(img_buf.getvalue())

  def _decode_jpg(image):
    """Decode one base64 JPEG string and return its max pixel value (uint8)."""
    # Substitute the default image for empty values so decoding never fails
    # on a missing image column.
    image = tf.where(tf.equal(image, ''), default_image_string, image)
    image = tf.decode_base64(image)
    image = tf.image.decode_jpeg(image, channels=3)
    # Flatten to 1-D, then reduce to a single scalar.
    image = tf.reshape(image, [-1])
    return tf.reduce_max(image)

  with tf.Session(graph=tf.Graph()) as sess:
    # One default per CSV column; the defaults' dtypes fix the parsed dtypes.
    record_defaults = [
        tf.constant([0], dtype=tf.int64),
        tf.constant([0.0], dtype=tf.float32),
        tf.constant([''], dtype=tf.string),
        tf.constant([''], dtype=tf.string),
    ]
    placeholder = tf.placeholder(dtype=tf.string, shape=(None,), name='csv_input_placeholder')
    key_tensor, num_tensor, text_tensor, img_tensor = tf.decode_csv(placeholder, record_defaults)
    text_tensor = tf.string_to_number(text_tensor, tf.float32)
    img_tensor = tf.map_fn(_decode_jpg, img_tensor, back_prop=False, dtype=tf.uint8)
    img_tensor = tf.cast(img_tensor, tf.float32)
    # stacked has shape (3, batch); reducing over axis 0 gives one value per row.
    stacked = tf.stack([num_tensor, text_tensor, img_tensor])
    min_tensor = tf.reduce_min(stacked, axis=0)
    max_tensor = tf.reduce_max(stacked, axis=0)

    predict_input_tensor = tf.saved_model.utils.build_tensor_info(placeholder)
    predict_signature_inputs = {"input": predict_input_tensor}
    predict_output_tensor1 = tf.saved_model.utils.build_tensor_info(min_tensor)
    predict_output_tensor2 = tf.saved_model.utils.build_tensor_info(max_tensor)
    predict_key_tensor = tf.saved_model.utils.build_tensor_info(key_tensor)
    predict_signature_outputs = {
      'key': predict_key_tensor,
      'var1': predict_output_tensor1,
      'var2': predict_output_tensor2
    }
    predict_signature_def = (
        tf.saved_model.signature_def_utils.build_signature_def(
            predict_signature_inputs, predict_signature_outputs,
            tf.saved_model.signature_constants.PREDICT_METHOD_NAME
        )
    )
    # Use the fully qualified module path, as for PREDICT_METHOD_NAME above,
    # so this does not rely on a separate `signature_constants` import.
    signature_def_map = {
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            predict_signature_def
    }
    # The model is written under the test directory. No temporary directory is
    # created: an unused tempfile.mkdtemp() would leak a directory on disk.
    model_dir = os.path.join(self._test_dir, dir_name)
    builder = tf.saved_model.builder.SavedModelBuilder(model_dir)
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map=signature_def_map)
    # Serialize the SavedModel as a binary protobuf rather than text.
    builder.save(as_text=False)

  return model_dir
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号