def load_checkpoint(self, directory: str, train_iter: data_io.BaseParallelSampleIter) -> _TrainingState:
    """
    Restore a complete training checkpoint from ``directory``.

    The pieces are restored in this order: optimizer states, the state of
    the training data iterator, the RNG states of Python's ``random`` and of
    ``np.random``, the training monitor state and finally the trainer's own
    state. Model parameters are not handled here; they should have been
    loaded already by the initializer.

    :param directory: directory where the state has been saved.
    :param train_iter: training data iterator; its ``load_state`` is called
        with the saved iterator state file.
    :return: the result of ``self.load_state`` for the saved training state file.
    """

    def checkpoint_file(fname: str) -> str:
        # Resolve a checkpoint file name relative to ``directory``.
        return os.path.join(directory, fname)

    # 1) Optimizer state (from mxnet).
    self.load_optimizer_states(checkpoint_file(C.OPT_STATES_LAST))

    # 2) State of the bucket iterator.
    iterator_state_path = checkpoint_file(C.BUCKET_ITER_STATE_NAME)
    train_iter.load_state(iterator_state_path)

    # 3) RNG states. Python's random and np.random can store/restore their
    #    state; mxnet cannot, but mxnet's RNG is not used in our code AFAIK.
    #    The file holds two pickled objects back to back: Python's state
    #    first, then NumPy's.
    rng_state_path = checkpoint_file(C.RNG_STATE_NAME)
    with open(rng_state_path, "rb") as rng_file:
        python_rng_state = pickle.load(rng_file)
        random.setstate(python_rng_state)
        numpy_rng_state = pickle.load(rng_file)
        np.random.set_state(numpy_rng_state)

    # 4) Monitor state, needed for the full information about the metrics.
    monitor_state_path = checkpoint_file(C.MONITOR_STATE_NAME)
    self.training_monitor.load_state(monitor_state_path)

    # 5) Finally the trainer's own state, which is handed back to the caller.
    training_state = self.load_state(checkpoint_file(C.TRAINING_STATE_NAME))
    return training_state
# Comment list
# Article table of contents