def __init__(self, model_path, embedding_size, language, nlp):
    """Restore a saved TensorFlow 1.x model whose graph contains a ``tf.py_func``.

    The meta graph is read from ``model_path + "model.ckpt.meta"`` and the
    weights from ``model_path + "model.ckpt"``. ``model_path`` is concatenated
    as-is, so it presumably has to end with a path separator (or be a
    deliberate filename prefix) -- TODO confirm against callers.

    Args:
        model_path: Prefix prepended to the checkpoint file names.
        embedding_size: Forwarded to ``spacy_wrapper`` by the re-registered py_func.
        language: Forwarded to ``spacy_wrapper`` by the re-registered py_func.
        nlp: Forwarded to ``spacy_wrapper`` by the re-registered py_func.

    Attributes set:
        self.graph: The graph the meta graph was imported into.
        self.decoder_prediction, self.intent: Tensors looked up by name.
        self.words_inputs, self.encoder_inputs_actual_length: Tensors looked
            up by name (used as the py_func input / presumably feed inputs).
        self.sess: A session bound to ``self.graph`` with the checkpoint
            weights restored.
    """
    # Step 1: import the meta graph into a fresh graph. Everything that adds
    # ops or needs the graph (py_func, tables_initializer, the session) must
    # happen inside this context.
    with tf.Graph().as_default() as graph:
        saver = tf.train.import_meta_graph(model_path + "model.ckpt.meta")
        self.graph = graph

        # Get tensors for inputs and outputs by the names given at build time.
        self.decoder_prediction = graph.get_tensor_by_name('decoder_prediction:0')
        self.intent = graph.get_tensor_by_name('intent:0')
        self.words_inputs = graph.get_tensor_by_name('words_inputs:0')
        self.encoder_inputs_actual_length = graph.get_tensor_by_name(
            'encoder_inputs_actual_length:0')

        # The Python callable behind a py_func is not serializable, so it is
        # redefined here and registered again by creating a new py_func.
        def static_wrapper(words):
            return spacy_wrapper(embedding_size, language, nlp, words)

        # The returned tensor is not needed; the call is kept only for its
        # side effect of registering `static_wrapper`.
        # NOTE(review): this presumably relies on the new py_func receiving the
        # same registry token as the PyFunc op stored in the saved graph (e.g.
        # being the first py_func created in this process) -- confirm, as
        # creating other py_funcs earlier would break the match.
        tf.py_func(static_wrapper, [self.words_inputs], tf.float32, stateful=False)

        # Step 2: restore weights. The session is bound to `graph` explicitly
        # so restoring does not depend on which graph happens to be default.
        self.sess = tf.Session(graph=graph)
        self.sess.run(tf.tables_initializer())
        saver.restore(self.sess, model_path + "model.ckpt")
评论列表
文章目录