def test(self, data_iterator, is_log=False):
    """Evaluate the test model on one pass of ``data_iterator`` and return accuracy.

    For each of ``data_iterator.batch_per_epoch`` batches, builds the
    prediction ops via ``self._test_model.predict``, evaluates them with
    ``self._session.run``, and adds the count returned by
    ``self._check_predictions`` to a running total. It then divides by the
    summed ``batch.size`` values.

    Args:
        data_iterator: Source of test batches. Must expose ``batch_per_epoch``
            and ``get_batch()``. Each batch must expose ``size``,
            ``ground_truth``, ``ground_truth_segment_length``,
            ``ground_truth_segmentation_length`` and ``questions_length``.
        is_log: When True, each batch and its evaluated predictions are
            passed to ``self.log`` together with a per-run log file path.

    Returns:
        float: ``correct / total`` over all evaluated examples, or ``0.0``
        when no examples were evaluated.
    """
    tqdm.write("Testing...")
    total = 0
    correct = 0
    # One log file per run, named with the timestamp stored on the instance.
    log_path = os.path.join(
        self._result_log_base_path, "test_" + self._curr_time + ".log"
    )
    for _ in tqdm(range(data_iterator.batch_per_epoch)):
        batch = data_iterator.get_batch()
        # predict() returns fetchable handles plus the feed dict to run them
        # with (presumably TF1-style graph tensors -- confirm); session.run
        # turns them into concrete values. Keep the two under distinct names.
        tag_op, segment_length_op, feed_dict = self._test_model.predict(batch)
        tag_predictions, segment_length_predictions = self._session.run(
            (tag_op, segment_length_op),
            feed_dict=feed_dict,
        )
        correct += self._check_predictions(
            tag_predictions=tag_predictions,
            segment_length_predictions=segment_length_predictions,
            ground_truth=batch.ground_truth,
            ground_truth_segment_length=batch.ground_truth_segment_length,
            ground_truth_segmentation_length=batch.ground_truth_segmentation_length,
            question_length=batch.questions_length,
        )
        total += batch.size
        if is_log:
            self.log(
                file=log_path,
                batch=batch,
                tag_predictions=tag_predictions,
                segment_length_predictions=segment_length_predictions,
            )
    if total == 0:
        # An empty iterator (batch_per_epoch == 0) or all-empty batches would
        # otherwise crash with ZeroDivisionError; report 0.0 and warn instead.
        tqdm.write("Warning: no test examples were evaluated; reporting accuracy 0.0")
        accuracy = 0.0
    else:
        accuracy = float(correct) / float(total)
    tqdm.write("test_acc: %f" % accuracy)
    return accuracy
# [Scraped web-page residue, not code: "评论列表" (comment list), "文章目录" (article table of contents)]