def test(self, data, session):
    """Evaluate the model on ``data`` and report root-label accuracy.

    Runs the graph once per batch, collecting predicted and true root
    labels, then prints the accuracy and the confusion matrix.

    Args:
        data: iterable of batches. Only ``batch[0]`` is used; it must expose
            ``root_labels`` and be accepted by ``self.tree_lstm.get_feed_dict``.
        session: TensorFlow session used to evaluate the graph.

    Returns:
        The accuracy computed by ``metrics.accuracy_score``.
    """
    # Build the evaluation ops ONCE. Creating them inside the loop added new
    # nodes to the graph on every batch, growing it without bound.
    # NOTE(review): assumes get_output() is (batch, num_classes) so argmax over
    # axis 1 yields the predicted class index -- confirm.
    pred_op = tf.argmax(self.get_output(), 1)
    true_op = self.labels

    ys_true = []
    ys_pred = []
    for batch in data:
        trees = batch[0]
        feed_dict = {self.labels: trees.root_labels}
        feed_dict.update(self.tree_lstm.get_feed_dict(trees))
        # Use distinct names for the fetched arrays so the op handles stay
        # intact for the next iteration.
        pred_vals, true_vals = session.run([pred_op, true_op],
                                           feed_dict=feed_dict)
        ys_true.extend(true_vals.tolist())
        ys_pred.extend(pred_vals.tolist())

    score = metrics.accuracy_score(ys_true, ys_pred)
    # Single-argument print(...) produces the same output on Python 2 and 3.
    print("Accuracy %s" % score)
    print("confusion_matrix")
    print(metrics.confusion_matrix(ys_true, ys_pred))
    return score
评论列表
文章目录