def test(self):
    """Restore the latest checkpoint, then run interactively or score the test set.

    Looks up the checkpoint state under ``self.model_dir`` and restores it
    into ``self.sess`` through ``self.saver`` when a checkpoint path is
    recorded; otherwise a notice is printed and execution continues with
    whatever variables the session currently holds.

    When ``self.isInteractive`` is truthy, control is handed to
    ``self.interactive()``. Otherwise ``self.testData`` is vectorized,
    predictions are produced with ``self.batch_predict``, and the test-set
    size and accuracy are printed. Nothing is returned.
    """
    state = tf.train.get_checkpoint_state(self.model_dir)
    has_checkpoint = bool(state and state.model_checkpoint_path)
    if not has_checkpoint:
        print("...no checkpoint found...")
    else:
        self.saver.restore(self.sess, state.model_checkpoint_path)

    # Interactive mode bypasses the batch evaluation entirely.
    if self.isInteractive:
        self.interactive()
        return

    test_p, test_s, test_q, test_a = vectorize_data(
        self.testData,
        self.word_idx,
        self.sentence_size,
        self.batch_size,
        self.n_cand,
        self.memory_size,
    )
    example_count = len(test_s)
    print("Testing Size", example_count)

    predictions = self.batch_predict(test_p, test_s, test_q, example_count)
    accuracy = metrics.accuracy_score(predictions, test_a)
    print("Testing Accuracy:", accuracy)
    # Debugging aid: to inspect individual results, print test_a and map each
    # prediction through self.indx2candid.
评论列表
文章目录