def __init__(self,
             batch_size: int,
             output_folder: str,
             optimized_metric: str = C.PERPLEXITY,
             use_tensorboard: bool = False,
             cp_decoder: Optional[checkpoint_decoder.CheckpointDecoder] = None) -> None:
    """
    Set up metric bookkeeping, optional Tensorboard logging, and the
    multiprocessing queue/context used for checkpoint decoding.

    Arguments are validated before any side effect. In particular, validation
    runs before an existing Tensorboard log directory is deleted, so a rejected
    configuration leaves the files on disk untouched.

    :param batch_size: Training batch size. NOTE(review): not read or stored by
                       this constructor; kept for interface compatibility.
    :param output_folder: Folder that holds the metrics file (``C.METRICS_NAME``)
                          and, if enabled, the Tensorboard log directory.
    :param optimized_metric: Name of the metric used for early stopping; must be
                             a member of ``C.METRICS``.
    :param use_tensorboard: If True, recreate ``output_folder/C.TENSORBOARD_NAME``
                            (deleting any existing one) and open a writer on it.
    :param cp_decoder: Optional checkpoint decoder; required when
                       ``optimized_metric`` is ``C.BLEU``.
    """
    # Validate up front. The Tensorboard setup below may rmtree an existing log
    # directory, and that must not happen for a configuration we then reject.
    utils.check_condition(optimized_metric in C.METRICS, "Unsupported metric: %s" % optimized_metric)
    if optimized_metric == C.BLEU:
        utils.check_condition(cp_decoder is not None, "%s requires CheckpointDecoder" % C.BLEU)

    self.output_folder = output_folder
    # stores dicts of metric names & values for each checkpoint
    self.metrics = []  # type: List[Dict]
    self.metrics_filename = os.path.join(output_folder, C.METRICS_NAME)
    self.best_checkpoint = 0
    self.start_tic = time.time()

    self.summary_writer = None
    if use_tensorboard:
        # Imported lazily so tensorboard is only needed when it is requested.
        import tensorboard  # pylint: disable=import-error
        log_dir = os.path.join(output_folder, C.TENSORBOARD_NAME)
        if os.path.exists(log_dir):
            # Start from a clean log dir; stale events would otherwise mix in.
            logger.info("Deleting existing tensorboard log dir %s", log_dir)
            shutil.rmtree(log_dir)
        logger.info("Logging training events for Tensorboard at '%s'", log_dir)
        self.summary_writer = tensorboard.FileWriter(log_dir)

    self.cp_decoder = cp_decoder
    # Dedicated 'spawn' start-method context. The queue comes from the same
    # context; presumably the decoder process is launched from it elsewhere.
    self.ctx = mp.get_context('spawn')  # type: ignore
    self.decoder_metric_queue = self.ctx.Queue()
    self.decoder_process = None  # type: Optional[mp.Process]

    self.optimized_metric = optimized_metric
    # Start from the metric's worst value, so the first validation result counts as an improvement.
    self.validation_best = C.METRIC_WORST[self.optimized_metric]
    logger.info("Early stopping by optimizing '%s'", self.optimized_metric)
    self.tic = 0
评论列表
文章目录