pridict.py source code

python

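Project: deeplearning  Author: fanfanfeng
# Third-party dependencies used below.
import jieba
import numpy as np
import tensorflow as tf  # TF 1.x API (tf.Graph, tf.Session)
from gensim.models import Word2Vec

# Project-specific modules from the deeplearning repo (their import paths are
# not shown in this file): tv_classfication (settings such as label_list and
# word2vec_path), data_convert (SimpleTextConverter), bi_lstm_model (Bi_lstm).


def predict(text):
    # Segment the Chinese text with jieba and join the tokens with spaces.
    words = " ".join(jieba.cut(text))
    # Map each class index to its label name.
    index2label = {i: l.strip() for i, l in enumerate(tv_classfication.label_list)}

    # Convert the tokens to word ids using the trained word2vec model;
    # 80 is presumably the fixed sequence length used by the converter.
    word2vec_model = Word2Vec.load(tv_classfication.word2vec_path)
    text_converter = data_convert.SimpleTextConverter(word2vec_model, 80, None)
    x_test = np.array([doc for doc, _ in text_converter.transform_to_ids([words])])

    # Build a fresh graph, restore the trained Bi-LSTM weights and predict.
    graph = tf.Graph()
    with graph.as_default(), tf.Session() as sess:
        model = bi_lstm_model.Bi_lstm()
        model.restore_model(sess)
        # Decode the predicted class index with the local index2label map.
        print(index2label.get(model.predict(sess, x_test)[0]))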
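SimpleTextConverter is not shown in this file, so the following is only a rough, self-contained sketch of what the preprocessing and decoding steps around the model presumably do. It assumes a plain dict vocabulary in place of the gensim model, a padding id of 0 and an unknown-word id of 1; `words_to_ids`, `decode_label`, `PAD_ID` and `UNK_ID` are hypothetical names, not part of the deeplearning project, and the labels in the demo are made up.

```python
from typing import Dict, List, Sequence

PAD_ID = 0  # assumption: id used to pad short sequences
UNK_ID = 1  # assumption: id used for out-of-vocabulary tokens


def words_to_ids(words: str, vocab: Dict[str, int], max_len: int = 80) -> List[int]:
    """Turn a space-joined token string into a fixed-length list of ids.

    Hypothetical stand-in for SimpleTextConverter.transform_to_ids: each token
    is looked up in `vocab` (unknown tokens become UNK_ID), then the list is
    truncated or right-padded with PAD_ID to exactly `max_len` entries.
    """
    # Skip empty strings produced by consecutive spaces.
    ids = [vocab.get(tok, UNK_ID) for tok in words.split(" ") if tok]
    ids = ids[:max_len]
    return ids + [PAD_ID] * (max_len - len(ids))


def decode_label(pred_index: int, label_list: Sequence[str]) -> str:
    """Map a predicted class index back to its label name, as index2label does."""
    index2label = {i: l.strip() for i, l in enumerate(label_list)}
    return index2label.get(pred_index, "<unknown>")


if __name__ == "__main__":
    vocab = {"我": 2, "想": 3, "看": 4, "电视剧": 5}  # toy vocabulary
    print(words_to_ids("我 想 看 电视剧 吧", vocab, max_len=8))  # [2, 3, 4, 5, 1, 0, 0, 0]
    labels = ["movie\n", "tv_show\n", "variety\n"]  # made-up labels
    print(decode_label(1, labels))  # tv_show
```

With gensim 4, a comparable word-to-index dict is available as `word2vec_model.wv.key_to_index`, although the project's converter may assign ids differently.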