def __init__(self,
             config: model.ModelConfig,
             context: List[mx.context.Context],
             train_iter: data_io.BaseParallelSampleIter,
             bucketing: bool,
             lr_scheduler,
             gradient_compression_params: Optional[Dict[str, Any]] = None) -> None:
    """
    Initialize the model for training.

    Forwards ``config`` to the parent constructor, records the remaining
    constructor arguments on the instance, builds the model components and
    then the module (via ``_build_module(train_iter)``), and starts with no
    training monitor attached.

    :param config: Model configuration, passed unchanged to the parent constructor.
    :param context: List of MXNet contexts, stored as ``self.context``.
    :param train_iter: Training data iterator handed to ``_build_module``.
    :param bucketing: Stored as ``self.bucketing``; presumably consulted by
        ``_build_module`` to choose a module type — confirm there.
    :param lr_scheduler: Learning-rate scheduler, stored as ``self.lr_scheduler``.
    :param gradient_compression_params: Optional gradient-compression settings,
        stored as ``self.gradient_compression_params`` (``None`` by default).
    """
    super().__init__(config)

    # Record constructor arguments before anything is built: the build
    # helpers below may read them back from ``self`` (NOTE(review): confirm
    # which of them _build_model_components/_build_module actually use).
    self.context = context
    self.lr_scheduler = lr_scheduler
    self.bucketing = bucketing
    self.gradient_compression_params = gradient_compression_params

    # Components must exist before the module that wraps them is created.
    self._build_model_components()
    built_module = self._build_module(train_iter)
    self.module = built_module

    # No monitor is attached at construction time.
    self.training_monitor = None  # type: Optional[callback.TrainingMonitor]
评论列表
文章目录