callback.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目: sockeye 作者: awslabs 项目源码 文件源码
def __init__(self,
                 batch_size: int,
                 output_folder: str,
                 optimized_metric: str = C.PERPLEXITY,
                 use_tensorboard: bool = False,
                 cp_decoder: Optional[checkpoint_decoder.CheckpointDecoder] = None) -> None:
        """
        Set up checkpoint bookkeeping, optional TensorBoard logging and the
        queue used to receive metrics from a checkpoint-decoder process.

        Arguments are validated before any side effect takes place. An
        unsupported ``optimized_metric``, or BLEU without a decoder, is
        therefore rejected without first deleting an existing TensorBoard log
        directory or allocating a multiprocessing queue.

        :param batch_size: Not used by this constructor.
            NOTE(review): presumably kept for interface compatibility -- confirm with callers.
        :param output_folder: Directory holding the metrics file (``C.METRICS_NAME``) and,
            when enabled, the TensorBoard log directory (``C.TENSORBOARD_NAME``).
        :param optimized_metric: Metric used for early stopping; must be one of ``C.METRICS``.
            ``C.BLEU`` additionally requires ``cp_decoder``.
        :param use_tensorboard: If True, remove any existing TensorBoard log directory under
            ``output_folder`` and open a fresh ``tensorboard.FileWriter`` there.
        :param cp_decoder: Optional checkpoint decoder; required when optimizing BLEU.

        Failed checks are reported through ``utils.check_condition`` (presumably by raising;
        the exception type is defined outside this block).
        """
        # Validate first: the TensorBoard branch below irreversibly deletes an
        # existing log directory, which must not happen for invalid arguments.
        utils.check_condition(optimized_metric in C.METRICS, "Unsupported metric: %s" % optimized_metric)
        if optimized_metric == C.BLEU:
            utils.check_condition(cp_decoder is not None, "%s requires CheckpointDecoder" % C.BLEU)

        self.output_folder = output_folder
        # stores dicts of metric names & values for each checkpoint
        self.metrics = []  # type: List[Dict]
        self.metrics_filename = os.path.join(output_folder, C.METRICS_NAME)
        self.best_checkpoint = 0
        self.start_tic = time.time()

        self.summary_writer = None
        if use_tensorboard:
            # Imported lazily so tensorboard is only needed when it is requested.
            import tensorboard  # pylint: disable=import-error
            log_dir = os.path.join(output_folder, C.TENSORBOARD_NAME)
            if os.path.exists(log_dir):
                logger.info("Deleting existing tensorboard log dir %s", log_dir)
                shutil.rmtree(log_dir)
            logger.info("Logging training events for Tensorboard at '%s'", log_dir)
            self.summary_writer = tensorboard.FileWriter(log_dir)

        self.cp_decoder = cp_decoder
        # Queue created from a 'spawn' context; presumably consumed from a
        # decoder subprocess started elsewhere (decoder_process starts as None).
        self.ctx = mp.get_context('spawn')  # type: ignore
        self.decoder_metric_queue = self.ctx.Queue()
        self.decoder_process = None  # type: Optional[mp.Process]

        self.optimized_metric = optimized_metric
        self.validation_best = C.METRIC_WORST[self.optimized_metric]
        logger.info("Early stopping by optimizing '%s'", self.optimized_metric)
        self.tic = 0
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号