runtime.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:entity_binding 作者: JasperGuo 项目源码 文件源码
def test(self, data_iterator, is_log=False):
        """Evaluate the test model over one pass of ``data_iterator`` and return accuracy.

        Pulls ``data_iterator.batch_per_epoch`` batches via ``get_batch()``, runs
        the fetches returned by ``self._test_model.predict(batch)`` in
        ``self._session``, and accumulates the value returned by
        ``self._check_predictions`` against the running sum of ``batch.size``.

        Args:
            data_iterator: Object exposing ``batch_per_epoch`` and ``get_batch()``.
            is_log: When True, each batch and its fetched predictions are passed
                to ``self.log`` with the file
                ``<self._result_log_base_path>/test_<self._curr_time>.log``.

        Returns:
            float: ``correct / total``. If no examples were seen (``total == 0``),
            a warning is written and 0.0 is returned instead of raising
            ``ZeroDivisionError``.
        """
        tqdm.write("Testing...")
        total = 0
        correct = 0
        # Named ``log_file`` rather than ``file`` to avoid shadowing the Python 2 builtin.
        log_file = os.path.join(self._result_log_base_path, "test_" + self._curr_time + ".log")
        for _ in tqdm(range(data_iterator.batch_per_epoch)):
            batch = data_iterator.get_batch()
            # predict() yields the fetches plus the feed_dict required to evaluate them;
            # keep fetches and fetched values under distinct names.
            tag_fetch, segment_length_fetch, feed_dict = self._test_model.predict(batch)
            tag_predictions, segment_length_predictions = self._session.run(
                (tag_fetch, segment_length_fetch),
                feed_dict=feed_dict,
            )

            # NOTE(review): presumably returns the number of fully-correct examples
            # in this batch; confirm against _check_predictions.
            correct += self._check_predictions(
                tag_predictions=tag_predictions,
                segment_length_predictions=segment_length_predictions,
                ground_truth=batch.ground_truth,
                ground_truth_segment_length=batch.ground_truth_segment_length,
                ground_truth_segmentation_length=batch.ground_truth_segmentation_length,
                question_length=batch.questions_length
            )

            total += batch.size

            if is_log:
                self.log(
                    file=log_file,
                    batch=batch,
                    tag_predictions=tag_predictions,
                    segment_length_predictions=segment_length_predictions
                )

        if total == 0:
            # An empty iterator (batch_per_epoch == 0 or only empty batches) would
            # otherwise crash with ZeroDivisionError.
            tqdm.write("test_acc: no test examples evaluated, reporting 0.0")
            return 0.0

        accuracy = float(correct) / float(total)
        tqdm.write("test_acc: %f" % accuracy)
        return accuracy
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号