runtime.py source code

python

Project: entity_binding  Author: JasperGuo
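import os

from tqdm import tqdm


def test(self, data_iterator, is_log=False):
    """Evaluate the test model on one epoch of batches and return its accuracy."""
    tqdm.write("Testing...")
    total = 0
    correct = 0
    log_file = os.path.join(self._result_log_base_path, "test_" + self._curr_time + ".log")
    for _ in tqdm(range(data_iterator.batch_per_epoch)):
        batch = data_iterator.get_batch()
        # Get the prediction tensor(s) and feed dict for this batch,
        # then evaluate them in the TensorFlow session.
        predictions, feed_dict = self._test_model.predict(batch)
        predictions = self._session.run(predictions, feed_dict=feed_dict)

        # Add the number of correctly predicted examples in this batch.
        correct += self._check_predictions(
            predictions=predictions,
            ground_truth=batch.ground_truth
        )

        total += batch.size

        if is_log:
            self.log(
                file=log_file,
                batch=batch,
                predictions=predictions
            )

    # Example-level accuracy over the whole epoch; avoid dividing by zero
    # when the iterator yields no batches.
    accuracy = float(correct) / float(total) if total else 0.0
    tqdm.write("test_acc: %f" % accuracy)
    return accuracy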
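The accumulation logic in `test` does not depend on TensorFlow, so it can be reproduced in plain Python. The following is a minimal sketch under stated assumptions: `Batch`, `check_predictions`, and `evaluate` are hypothetical names, batches come from an ordinary iterator rather than the project's `data_iterator`, and a prediction is counted as correct only when the whole predicted sequence matches the reference. The real `_check_predictions` is not shown in the source, so that matching rule is a guess.

```python
from dataclasses import dataclass
from typing import Callable, Iterator, List, Sequence


@dataclass
class Batch:
    """Hypothetical stand-in for the project's batch object."""
    inputs: List[List[int]]
    ground_truth: List[List[int]]

    @property
    def size(self) -> int:
        return len(self.ground_truth)


def check_predictions(predictions: Sequence[Sequence[int]],
                      ground_truth: Sequence[Sequence[int]]) -> int:
    """Count examples whose predicted sequence equals the reference exactly.

    Assumption: exact-sequence matching; the original helper is not shown.
    """
    return sum(1 for p, g in zip(predictions, ground_truth) if list(p) == list(g))


def evaluate(batches: Iterator[Batch],
             predict: Callable[[Batch], List[List[int]]]) -> float:
    """Sum correct/total over all batches, then divide once at the end."""
    total = 0
    correct = 0
    for batch in batches:
        predictions = predict(batch)
        correct += check_predictions(predictions, batch.ground_truth)
        total += batch.size
    # Return 0.0 for an empty iterator instead of raising ZeroDivisionError.
    return correct / total if total else 0.0


batches = [
    Batch(inputs=[[1, 2], [3, 4]], ground_truth=[[1, 2], [3, 4]]),
    Batch(inputs=[[5, 6]], ground_truth=[[6, 5]]),
]
# Hypothetical identity "model" that echoes its inputs as predictions.
acc = evaluate(iter(batches), lambda b: b.inputs)
print("test_acc: %f" % acc)  # test_acc: 0.666667
```

Because `correct` and `total` are summed across batches before dividing, the result is per-example accuracy. Averaging per-batch accuracies instead would give 0.5 here, overweighting the smaller final batch.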