process.py source code

python

Project: mpi_learn    Author: duanders
from mpi4py import MPI  # needed for MPI.Status below

def train(self):
    """Broadcast model information to children and signal them to start training.
    Receive messages from workers and process each message until training is done.
    When finished, signal the parent process that training is complete.
    """
    self.check_sanity()
    # Send the current weights to all children, set up callbacks, then start the children.
    self.bcast_weights(comm=self.child_comm)
    self.init_callbacks(for_worker=self.has_parent)
    self.callbacks.on_train_begin()
    self.signal_children()

    status = MPI.Status()
    # Child ranks 1..num_workers; the loop below runs until this list is empty.
    self.running_workers = list(range(1, self.num_workers + 1))
    self.waiting_workers_list = []

    self.epoch = 0
    self.callbacks.on_epoch_begin(self.epoch)
    # Main event loop: receive a message from any child and handle it.
    while self.running_workers:
        self.recv_any_from_child(status)
        self.process_message(status)
        # If a callback (e.g. early stopping) requested a stop, shut the workers down once.
        if (not self.stop_training) and self.callback_model.stop_training:
            self.shut_down_workers()
            self.stop_training = True
    print("MPIMaster {0:d} done training".format(self.rank))
    # If we did not finish the last epoch, validate one more time.
    # (This happens if the batch size does not divide the dataset size.)
    if self.epoch < self.num_epochs:
        epoch_logs = self.validate()
        self.callbacks.on_epoch_end(self.epoch, epoch_logs)
    self.histories[str(self.rank)] = self.model.history.history
    self.send_exit_to_parent()
    self.callbacks.on_train_end()
    self.send_history_to_parent()
    # Only the top-level master (one with no parent) returns the collected histories.
    if not self.has_parent:
        return self.histories
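The master's train() loop is driven entirely by MPI messages. The sketch below is a single-process illustration of that control flow, not mpi_learn's actual implementation.

It makes several assumptions, all hypothetical and introduced only for illustration:
- recv_any_from_child() is replaced by a queue of (source_rank, tag) pairs.
- The tag names UPDATE and EXIT stand in for the real message tags.
- The ToyMaster class stands in for the master process.
- One epoch is assumed to equal a fixed number of worker updates.

```python
from collections import deque
from dataclasses import dataclass
from typing import Deque, List, Optional, Tuple

# Hypothetical message tags; the real mpi_learn tag names are not shown in the snippet.
UPDATE = "update"
EXIT = "exit"


@dataclass
class ToyMaster:
    """Single-process stand-in for the master's train() loop (illustrative only)."""
    num_workers: int
    batches_per_epoch: int  # assumption: one epoch == this many worker updates
    num_epochs: int
    stop_after_updates: Optional[int] = None  # stands in for callback_model.stop_training

    def train(self, inbox: Deque[Tuple[int, str]]) -> dict:
        running: List[int] = list(range(1, self.num_workers + 1))
        stop_training = False
        shut_down: List[int] = []
        epoch = updates = validations = 0
        while running:
            # Stands in for recv_any_from_child(): next message from any worker.
            source, tag = inbox.popleft()
            # Stands in for process_message().
            if tag == UPDATE:
                updates += 1
                if updates % self.batches_per_epoch == 0:
                    validations += 1  # validate at the end of each full epoch
                    epoch += 1
            elif tag == EXIT:
                running.remove(source)  # worker finished; loop ends when none remain
            # Shut workers down at most once, the first time early stopping fires.
            early_stop = (self.stop_after_updates is not None
                          and updates >= self.stop_after_updates)
            if not stop_training and early_stop:
                shut_down = list(running)
                stop_training = True
        # Mirror the original: validate once more if the last epoch was partial.
        if epoch < self.num_epochs:
            validations += 1
        return {"updates": updates, "epochs": epoch, "validations": validations,
                "stopped_early": stop_training, "shut_down": shut_down}


if __name__ == "__main__":
    msgs = deque([(1, UPDATE), (2, UPDATE), (1, UPDATE), (2, UPDATE),
                  (1, EXIT), (2, UPDATE), (2, EXIT)])
    print(ToyMaster(num_workers=2, batches_per_epoch=3, num_epochs=2).train(msgs))
```

As in the original, the loop exits only after every worker rank has reported that it is done. This is why the early-stop path merely sets stop_training and keeps receiving instead of breaking out. If the scripted queue runs dry before all workers exit, popleft() raises IndexError, whereas a real MPI receive would block.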