def predict(text):
    """Classify a single piece of text, print the predicted label and return it.

    The text is segmented with ``jieba.cut``, joined into one space-separated
    string, converted to id sequences by ``data_convert.SimpleTextConverter``
    (backed by the Word2Vec model at ``tv_classfication.word2vec_path``), and
    fed to the model built by ``bi_lstm_model.Bi_lstm``, whose weights are
    restored into a fresh TensorFlow graph/session.

    Args:
        text: The raw input string to classify.

    Returns:
        The predicted label string (also printed), or ``None`` when the
        predicted index has no entry in ``tv_classfication.label_list``.
    """
    # jieba.cut yields tokens lazily; the converter is given one
    # space-separated string per document.
    words = " ".join(jieba.cut(text))

    # index -> label mapping built from the configured label list (lines are
    # stripped of surrounding whitespace). This mapping is what the final
    # lookup uses, rather than a separate module-level attribute.
    index2label = {
        i: label.strip() for i, label in enumerate(tv_classfication.label_list)
    }

    # NOTE(review): the Word2Vec model is reloaded on every call, which is
    # costly if predict() is invoked repeatedly — consider caching it.
    word2vec_model = Word2Vec.load(tv_classfication.word2vec_path)
    # NOTE(review): 80 is presumably the maximum sequence length expected by
    # the converter/model — confirm against data_convert.SimpleTextConverter.
    text_converter = data_convert.SimpleTextConverter(word2vec_model, 80, None)
    x_test = np.array(
        [doc for doc, _ in text_converter.transform_to_ids([words])]
    )

    # Build the model in its own graph and bind the session to that graph
    # explicitly, so it never depends on whatever the default graph is.
    graph = tf.Graph()
    with graph.as_default(), tf.Session(graph=graph) as sess:
        model = bi_lstm_model.Bi_lstm()
        model.restore_model(sess)
        predicted_index = model.predict(sess, x_test)[0]

    label = index2label.get(predicted_index)
    print(label)
    return label
评论列表
文章目录