def do(self, which_callback, *args):
    """Dump an incremental snapshot of training state once burn-in is reached.

    Does nothing while ``self.main_loop.status['iterations_done']`` is below
    ``self.burnin``. After that, each call writes:

    * the parameter values of every model in ``self.main_loop.models``,
      combined with ``merge``, to ``<saveto>/params_iter<N>.npz``;
    * if ``self.save_iter_state`` is truthy, ``self.main_loop.iteration_state``
      pickled to ``<saveto>/iterations_state_iter<N>.pkl``;
    * if ``self.save_log`` is truthy, ``self.main_loop.log`` pickled to
      ``<saveto>/log_iter<N>``;

    where ``<N>`` is the current number of completed iterations.

    SIGINT is ignored while the files are written (presumably so an
    interrupt cannot cut a dump short). The previous SIGINT handler is
    always restored, even if one of the writes raises.

    Args:
        which_callback: Identifier of the callback that triggered this call;
            not used here.
        *args: Extra callback arguments; not used here.
    """
    iterations_done = self.main_loop.status['iterations_done']
    if iterations_done < self.burnin:
        return

    filename = os.path.join(
        self.saveto, 'params_iter{}.npz'.format(iterations_done))
    # NOTE: signal.signal may only be called from the main thread.
    previous_sigint = signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        logger.info(" Incremental dump %s", filename)
        # One parameter collection per model; ``merge`` presumably folds
        # them into a single mapping -- confirm against its definition.
        params_to_save = merge(
            [model.get_param_values()
             for model in self.main_loop.models.values()])
        secure_numpy_save(params_to_save, filename)

        if self.save_iter_state:
            filename_is = os.path.join(
                self.saveto,
                'iterations_state_iter{}.pkl'.format(iterations_done))
            logger.info(" Incremental dump %s", filename_is)
            secure_pickle_dump(self.main_loop.iteration_state, filename_is)

        if self.save_log:
            filename_log = os.path.join(
                self.saveto,
                'log_iter{}'.format(iterations_done))
            logger.info(" Incremental dump %s", filename_log)
            secure_pickle_dump(self.main_loop.log, filename_log)
    finally:
        # Restore unconditionally: without ``finally`` a failed write would
        # leave Ctrl-C ignored for the remainder of the run.
        signal.signal(signal.SIGINT, previous_sigint)
评论列表
文章目录