def test(self, data_iterator, is_log=False):
    """Evaluate the test model over one pass of ``data_iterator`` and return accuracy.

    Pulls ``data_iterator.batch_per_epoch`` batches via ``get_batch()``, obtains
    fetches and a feed dict from ``self._test_model.predict(batch)``, evaluates
    them with ``self._session.run``, and accumulates the number of correct
    predictions reported by ``self._check_predictions``. Progress and the final
    accuracy are reported through ``tqdm.write``.

    Args:
        data_iterator: Object exposing ``batch_per_epoch`` and ``get_batch()``;
            each returned batch must expose ``ground_truth`` and ``size``.
        is_log: When True, every batch and its evaluated predictions are passed
            to ``self.log`` together with a log-file path of the form
            ``<self._result_log_base_path>/test_<self._curr_time>.log``.

    Returns:
        float: ``correct / total`` over all evaluated examples, or ``0.0`` when
        no examples were evaluated (e.g. ``batch_per_epoch`` is 0), instead of
        raising ``ZeroDivisionError``.
    """
    tqdm.write("Testing...")
    total = 0
    correct = 0

    # Build the log path only when logging was requested; it is unused otherwise.
    log_file = None
    if is_log:
        log_file = os.path.join(
            self._result_log_base_path, "test_" + self._curr_time + ".log"
        )

    for _ in tqdm(range(data_iterator.batch_per_epoch)):
        batch = data_iterator.get_batch()
        # predict() yields fetches plus the feed_dict needed to evaluate them;
        # the session run turns them into concrete predictions.
        prediction_fetches, feed_dict = self._test_model.predict(batch)
        predictions = self._session.run(prediction_fetches, feed_dict=feed_dict)
        correct += self._check_predictions(
            predictions=predictions,
            ground_truth=batch.ground_truth,
        )
        total += batch.size
        if is_log:
            self.log(file=log_file, batch=batch, predictions=predictions)

    if total == 0:
        # Nothing was evaluated: avoid ZeroDivisionError and report it explicitly.
        tqdm.write("test_acc: n/a (no test examples)")
        return 0.0

    accuracy = float(correct) / float(total)
    tqdm.write("test_acc: %f" % accuracy)
    return accuracy
评论列表
文章目录