def train(self):
    """Run the master's training loop until every child worker has finished.

    Broadcasts model information to the children and signals them to start
    training, then receives messages from the workers and processes each
    one until no running workers remain.  When finished, signals the parent
    process that training is complete and sends the history to it.

    Returns:
        ``self.histories`` when this process has no parent
        (``self.has_parent`` is false); ``None`` otherwise.
    """
    self.check_sanity()
    self.bcast_weights(comm=self.child_comm)
    self.init_callbacks(for_worker=self.has_parent)
    self.callbacks.on_train_begin()
    self.signal_children()

    # One status object is reused for every receive and handed to
    # process_message(); presumably recv_any_from_child() fills it in with
    # the sender/tag of the message just received -- TODO confirm.
    status = MPI.Status()
    # Children are tracked by the integers 1..num_workers.
    self.running_workers = list(range(1, self.num_workers + 1))
    self.waiting_workers_list = []

    self.epoch = 0
    self.callbacks.on_epoch_begin(self.epoch)

    # The loop ends once running_workers is empty.  Nothing in this method
    # removes entries, so that presumably happens inside process_message()
    # or shut_down_workers() -- verify against those helpers.
    while self.running_workers:
        self.recv_any_from_child(status)
        self.process_message(status)
        # Shut the workers down only once: self.stop_training records that
        # the request from callback_model has already been acted on.
        # NOTE(review): self.stop_training is read here before this method
        # ever assigns it, so it must be initialised elsewhere.
        if (not self.stop_training) and self.callback_model.stop_training:
            self.shut_down_workers()
            self.stop_training = True

    print("MPIMaster {0:d} done training".format(self.rank))

    # If we did not finish the last epoch, validate one more time.
    # (this happens if the batch size does not divide the dataset size)
    if self.epoch < self.num_epochs:
        epoch_logs = self.validate()
        self.callbacks.on_epoch_end(self.epoch, epoch_logs)

    # Store this process's history under its rank (as a string key).
    self.histories[str(self.rank)] = self.model.history.history
    self.send_exit_to_parent()
    self.callbacks.on_train_end()
    self.send_history_to_parent()

    if not self.has_parent:
        return self.histories
    return None
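# Comment list (web-page navigation residue; not part of the code)
# Article table of contents (web-page navigation residue; not part of the code)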