runtime.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:entity_binding 作者: JasperGuo 项目源码 文件源码
def log(self, file, batch, tag_predictions, segment_length_predictions):
    """Append a human-readable prediction/ground-truth comparison to ``file``.

    One record is written per example in ``batch``. Each record starts with a
    ``=======================`` separator line and lists:

    - the question id and table id;
    - the number of columns and the length of the first column of
      ``cell_value_length``;
    - the unfolded ground truth and predictions returned by
      ``self._process_predictions``;
    - the raw ground-truth and predicted tag / segment-length sequences;
    - ``Result``: whether the unfolded prediction matches the unfolded
      ground truth exactly.

    Args:
        file: Path of the log file. It is opened in append mode.
        batch: Batch object read for ``ground_truth``,
            ``ground_truth_segmentation_length``, ``ground_truth_segment_length``,
            ``questions_length``, ``questions_ids``, ``cell_value_length`` and
            ``table_map_ids``. The per-example fields are iterated in lockstep.
        tag_predictions: Per-example predicted tag sequences.
        segment_length_predictions: Per-example predicted segment lengths.
    """
    unfold_predictions, unfold_ground_truth = self._process_predictions(
        tag_predictions=tag_predictions,
        segment_length_predictions=segment_length_predictions,
        ground_truth=batch.ground_truth,
        ground_truth_segmentation_length=batch.ground_truth_segmentation_length,
        ground_truth_segment_length=batch.ground_truth_segment_length,
        question_length=batch.questions_length
    )

    def _joined(values):
        # Render a sequence as "a, b, c".
        return ", ".join(str(value) for value in values)

    lines = []
    for tt, ts, pt, ps, qid, cv, table_id, unfold_p, unfold_t in zip(
            batch.ground_truth,
            batch.ground_truth_segment_length,
            tag_predictions,
            segment_length_predictions,
            batch.questions_ids,
            batch.cell_value_length,
            batch.table_map_ids,
            unfold_predictions,
            unfold_ground_truth
    ):
        # Sum of absolute element-wise differences; 0 means an exact match.
        # NOTE(review): assumes unfold_p and unfold_t have compatible shapes,
        # otherwise the subtraction raises. If they are 1-D the result is a
        # scalar and "Result" prints True/False.
        result = np.sum(np.abs(np.array(unfold_p) - np.array(unfold_t)), axis=-1)
        # Guard against an example with no columns: cv[0] would raise IndexError.
        # len() is used instead of truthiness so numpy arrays work too.
        max_cell_values = len(cv[0]) if len(cv) > 0 else 0
        lines.extend((
            "=======================",
            "id: " + str(qid),
            "tid: " + str(table_id),
            "max_column: " + str(len(cv)),
            "max_cell_value_per_col: " + str(max_cell_values),
            "unfold_t: " + _joined(unfold_t),
            "unfold_p: " + _joined(unfold_p),
            "ts: " + _joined(ts),
            "tt: " + _joined(tt),
            "pt: " + _joined(pt),
            "ps: " + _joined(ps),
            "Result: " + str(result == 0),
        ))

    # Format everything first, then open the file. This way a formatting error
    # never creates or touches the log file.
    with open(file, "a", encoding="utf-8") as f:
        f.write("".join(line + "\n" for line in lines))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号