predict.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:cnn-text-classification-tf-Chinese-py3 作者: MakDon 项目源码 文件源码
def predict(self, x_raw):
        """Run the loaded graph's ``output/predictions`` op on raw text.

        Each input string is stripped and split into individual characters
        (character-level tokens), padded with ``data_helpers.pad_sentences``
        to the module-level ``sequence_length``, and mapped to ids through
        the module-level ``vocabulary`` dict (characters missing from the
        vocabulary map to id 0). The ids are then fed through ``self.sess``
        in batches of ``FLAGS.batch_size`` with dropout disabled (keep
        probability 1.0).

        Args:
            x_raw: Iterable of strings to run through the model.

        Returns:
            A 1-D float64 numpy array holding the concatenated per-batch
            predictions, in batch order. It is empty if no batches were
            produced.
            NOTE(review): float64 is kept for backward compatibility with
            the previous implementation. If the op emits integer class ids,
            callers may want to cast back to int — confirm.
        """
        # Character-level tokenization: strip each sentence, then split it
        # into a list of its characters.
        x_raw = [list(s.strip()) for s in x_raw]
        x_pad, _ = data_helpers.pad_sentences(x_raw, sequence_length)
        # Characters not present in the vocabulary fall back to id 0.
        x_test = np.array([[vocabulary.get(word, 0) for word in sentence]
                           for sentence in x_pad])

        # Placeholders and the output tensor, looked up by name in the graph.
        input_x = self.graph.get_operation_by_name("input_x").outputs[0]
        dropout_keep_prob = self.graph.get_operation_by_name("dropout_keep_prob").outputs[0]
        predictions = self.graph.get_operation_by_name("output/predictions").outputs[0]

        # shuffle=False keeps batches in input order. The `1` is presumably
        # the epoch count of batch_iter — confirm against data_helpers.
        batches = data_helpers.batch_iter(list(x_test), FLAGS.batch_size, 1, shuffle=False)

        # Dropout is disabled at inference time (keep probability 1.0).
        batch_outputs = [
            self.sess.run(predictions, {input_x: x_test_batch, dropout_keep_prob: 1.0})
            for x_test_batch in batches
        ]

        # Concatenate once at the end. Growing the array per batch re-copied
        # every earlier result on each iteration (quadratic in the number of
        # batches). The empty float64 seed reproduces the dtype promotion of
        # the previous `[]`-seeded loop, and also yields an empty ndarray
        # (rather than a bare list) when there are no batches.
        return np.concatenate([np.empty(0)] + batch_outputs)

#test predict
#========================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号