model.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:botcycle 作者: D2KLab 项目源码 文件源码
def __init__(self, model_path: str, embedding_size, language, nlp) -> None:
        """Restore a saved TensorFlow model and open a session for inference.

        The graph structure is imported from ``model_path + "model.ckpt.meta"``
        and the variable values are restored from ``model_path + "model.ckpt"``.
        Both file names are appended to ``model_path`` by plain string
        concatenation. A directory path must therefore already end with a
        path separator.

        Args:
            model_path: String prefix of the checkpoint files (see above).
            embedding_size: Passed unchanged to ``spacy_wrapper``.
            language: Passed unchanged to ``spacy_wrapper``.
            nlp: Passed unchanged to ``spacy_wrapper``. Presumably a loaded
                spaCy pipeline -- TODO confirm against callers.

        Sets:
            self.graph: The ``tf.Graph`` the meta graph was imported into.
            self.decoder_prediction: Tensor ``decoder_prediction:0``.
            self.intent: Tensor ``intent:0``.
            self.words_inputs: Tensor ``words_inputs:0``.
            self.encoder_inputs_actual_length: Tensor
                ``encoder_inputs_actual_length:0``.
            self.sess: A ``tf.Session`` created while ``self.graph`` is the
                default graph, holding the restored variable values. It is
                kept open; nothing in this method closes it.
        """

        # Step 1: restore the meta graph (structure only, no variable values)
        # into a fresh graph, so it does not mix with any existing default graph.

        with tf.Graph().as_default() as graph:
            saver = tf.train.import_meta_graph(model_path + "model.ckpt.meta")

            self.graph = graph

            # Look up the input and output tensors of the imported graph by name.
            # Judging by the names, decoder_prediction and intent are presumably
            # the outputs, and words_inputs and encoder_inputs_actual_length the
            # inputs -- confirm against the training code.
            self.decoder_prediction = graph.get_tensor_by_name('decoder_prediction:0')
            self.intent = graph.get_tensor_by_name('intent:0')
            self.words_inputs = graph.get_tensor_by_name('words_inputs:0')
            self.encoder_inputs_actual_length = graph.get_tensor_by_name('encoder_inputs_actual_length:0')
            # Redefine the py_func that is not serializable. The Python body of a
            # tf.py_func is not stored in the meta graph, so the callable must be
            # registered again in this process.
            def static_wrapper(words):
                # Closes over the constructor arguments, so the py_func only has
                # to pass the value of words_inputs.
                return spacy_wrapper(embedding_size, language, nlp, words)

            # after_py_func is not used below. The call is kept only for its side
            # effect of registering static_wrapper with TensorFlow. Do not remove
            # or reorder it.
            # NOTE(review): this presumably works because the new registration
            # receives the same registry token that the saved PyFunc node refers
            # to (tokens appear to be assigned in registration order). That would
            # mean py_func creation order must match the order used when the
            # model was saved -- confirm.
            after_py_func = tf.py_func(static_wrapper, [self.words_inputs], tf.float32, stateful=False)

            # Step 2: restore weights
            # The session is created while `graph` is the default graph, so it
            # runs that graph.
            self.sess = tf.Session()
            # Run the initializers of all lookup tables in the default graph
            # (`graph`). Variable values come from saver.restore below.
            self.sess.run(tf.tables_initializer())
            saver.restore(self.sess, model_path + "model.ckpt")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号