def _create_model(self, dir_name):
    """Build and export a small SavedModel that consumes CSV rows.

    Each input row is a CSV string with four columns: 'key' (int64),
    'num1' (float32), 'text1' (a string holding a number) and 'img_url1'
    (a url-safe base64-encoded JPEG; an empty value is replaced by a
    16x16 default image). For every row the model outputs the key
    together with the minimum ('var1') and maximum ('var2') of the three
    numeric features: num1, text1 parsed as a float, and the largest
    pixel value of the decoded image.

    Args:
        dir_name: Name of the sub-directory, created under
            ``self._test_dir``, that the SavedModel is written to.

    Returns:
        The path of the directory containing the exported SavedModel.
    """
    # Plain-Python fallback image, built once at graph-construction time and
    # substituted for rows whose image column is empty.
    img_buf = BytesIO()
    Image.new('RGB', (16, 16)).save(img_buf, 'jpeg')
    default_image_string = base64.urlsafe_b64encode(img_buf.getvalue())

    def _decode_jpg(image):
        """Reduce one base64-encoded JPEG string to its maximum pixel value."""
        image = tf.where(tf.equal(image, ''), default_image_string, image)
        image = tf.decode_base64(image)
        image = tf.image.decode_jpeg(image, channels=3)
        # Flatten all pixels/channels so a single reduce_max yields a scalar.
        image = tf.reshape(image, [-1])
        return tf.reduce_max(image)

    # The export location is derived from the test directory; no separate
    # temporary directory is created (the previous mkdtemp() result was
    # never used and leaked a directory on every call).
    model_dir = os.path.join(self._test_dir, dir_name)

    with tf.Session(graph=tf.Graph()) as sess:
        # One default per CSV column; they also fix each column's dtype.
        record_defaults = [
            tf.constant([0], dtype=tf.int64),
            tf.constant([0.0], dtype=tf.float32),
            tf.constant([''], dtype=tf.string),
            tf.constant([''], dtype=tf.string),
        ]
        placeholder = tf.placeholder(
            dtype=tf.string, shape=(None,), name='csv_input_placeholder')
        key_tensor, num_tensor, text_tensor, img_tensor = tf.decode_csv(
            placeholder, record_defaults)

        # Bring every feature to float32 so they can be stacked together.
        text_tensor = tf.string_to_number(text_tensor, tf.float32)
        img_tensor = tf.map_fn(
            _decode_jpg, img_tensor, back_prop=False, dtype=tf.uint8)
        img_tensor = tf.cast(img_tensor, tf.float32)

        # Stack the three features along a new leading axis, then reduce
        # over that axis to get one min and one max per input row.
        stacked = tf.stack([num_tensor, text_tensor, img_tensor])
        min_tensor = tf.reduce_min(stacked, axis=0)
        max_tensor = tf.reduce_max(stacked, axis=0)

        predict_signature_inputs = {
            "input": tf.saved_model.utils.build_tensor_info(placeholder),
        }
        predict_signature_outputs = {
            'key': tf.saved_model.utils.build_tensor_info(key_tensor),
            'var1': tf.saved_model.utils.build_tensor_info(min_tensor),
            'var2': tf.saved_model.utils.build_tensor_info(max_tensor),
        }
        predict_signature_def = (
            tf.saved_model.signature_def_utils.build_signature_def(
                predict_signature_inputs,
                predict_signature_outputs,
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
            )
        )
        # Use the same tf.saved_model.signature_constants module as above
        # rather than a separately imported alias.
        signature_def_map = {
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                predict_signature_def,
        }

        builder = tf.saved_model.builder.SavedModelBuilder(model_dir)
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map=signature_def_map)
        # Write the binary protobuf form rather than the text form.
        builder.save(as_text=False)

    return model_dir
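# 评论列表 (comment list) -- web-page navigation residue, not part of the code
# 文章目录 (table of contents) -- web-page navigation residue, not part of the code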