nary_tree_lstm.py source code

python
Project: treelstm  Author: nicolaspi
# Imports assumed from the calls below (TensorFlow 1.x session API, scikit-learn metrics).
import tensorflow as tf
from sklearn import metrics

# Method of the model class (excerpt): evaluates root-label accuracy over `data`.
def test(self, data, session):
    ys_true = []
    ys_pred = []
    # Build the prediction op once; creating it inside the loop
    # would add a new argmax node to the graph for every batch.
    pred_op = tf.argmax(self.get_output(), 1)
    for batch in data:
        feed_dict = {self.labels: batch[0].root_labels}
        feed_dict.update(self.tree_lstm.get_feed_dict(batch[0]))
        y_pred, y_true = session.run([pred_op, self.labels], feed_dict=feed_dict)
        ys_true.extend(y_true.tolist())
        ys_pred.extend(y_pred.tolist())
    score = metrics.accuracy_score(ys_true, ys_pred)
    # Single-argument print calls behave the same on Python 2 and 3.
    print("Accuracy: %s" % score)
    # print("Recall: %s" % metrics.recall_score(ys_true, ys_pred))
    # print("f1_score: %s" % metrics.f1_score(ys_true, ys_pred))
    print("confusion_matrix")
    print(metrics.confusion_matrix(ys_true, ys_pred))
    return score
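
The two numbers that test() reports can be reproduced by hand on a small example. The following is a minimal, dependency-free sketch of what an accuracy score and a confusion matrix compute. It is not the scikit-learn implementation, and the helper names and toy label lists are made up for illustration.

```python
# Dependency-free sketch of the two metrics reported by test().
# Helper names and label values are hypothetical, for illustration only.

def accuracy(ys_true, ys_pred):
    """Fraction of positions where the prediction equals the true label."""
    if len(ys_true) != len(ys_pred):
        raise ValueError("label lists must have the same length")
    if not ys_true:
        raise ValueError("label lists must not be empty")
    correct = sum(1 for t, p in zip(ys_true, ys_pred) if t == p)
    return correct / float(len(ys_true))


def confusion_matrix(ys_true, ys_pred):
    """Rows are true classes, columns are predicted classes (labels sorted)."""
    labels = sorted(set(ys_true) | set(ys_pred))
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(ys_true, ys_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


toy_true = [0, 1, 2, 2, 1]
toy_pred = [0, 2, 2, 2, 1]
toy_accuracy = accuracy(toy_true, toy_pred)
toy_matrix = confusion_matrix(toy_true, toy_pred)
print("Accuracy: %s" % toy_accuracy)
for row in toy_matrix:
    print(row)
```

Each off-diagonal entry counts one kind of mistake: here the single error is a class-1 tree predicted as class 2, so it lands in row 1, column 2.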