def restore(self, checkpoint_dir=None):
    """Restore state previously written to a checkpoint directory.

    Looks for a TensorFlow checkpoint with prefix ``<checkpoint_dir>/params.chk``.
    Its ``.meta`` file is used as the existence marker. If found, this:

    1. restores the checkpointed variables into the default TF session;
    2. replaces ``progress.csv`` with ``progress.csv.chk`` when that file
       exists, re-registering it as a tabular logger output;
    3. loads ``self.pool`` from ``pool.chk`` using the serializer named by
       ``self.save_format`` (``'pickle'`` or ``'joblib'``).

    If no checkpoint is found, only a log message is emitted.

    Args:
        checkpoint_dir: Directory holding the checkpoint files. When ``None``,
            ``logger.get_snapshot_dir()`` is used.

    Raises:
        NotImplementedError: If ``self.save_format`` is neither ``'pickle'``
            nor ``'joblib'``. Note this is raised only after the TF variables
            and the tabular log have already been restored.
    """
    if checkpoint_dir is None:
        checkpoint_dir = logger.get_snapshot_dir()
    checkpoint_file = os.path.join(checkpoint_dir, 'params.chk')

    # tf.train.Saver.save writes '<prefix>.meta' next to the variable data,
    # so its presence signals that a checkpoint with this prefix exists.
    if not os.path.isfile(checkpoint_file + '.meta'):
        logger.log('No checkpoint %s' % checkpoint_file)
        return

    sess = tf.get_default_session()
    # A Saver constructed without var_list covers all saveable variables
    # in the current graph.
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_file)

    tabular_chk_file = os.path.join(checkpoint_dir, 'progress.csv.chk')
    if os.path.isfile(tabular_chk_file):
        tabular_file = os.path.join(checkpoint_dir, 'progress.csv')
        # Detach the logger from the live CSV before overwriting it with the
        # checkpointed copy, then attach it again. Presumably this makes
        # subsequent rows continue from the restored file.
        logger.remove_tabular_output(tabular_file)
        shutil.copy(tabular_chk_file, tabular_file)
        logger.add_tabular_output(tabular_file)

    pool_file = os.path.join(checkpoint_dir, 'pool.chk')
    if self.save_format == 'pickle':
        # Assign the loaded object, mirroring the joblib branch; the result
        # was previously dropped. Assumes pickle_load returns the unpickled
        # object -- TODO confirm against its definition.
        self.pool = pickle_load(pool_file)
    elif self.save_format == 'joblib':
        self.pool = joblib.load(pool_file)
    else:
        raise NotImplementedError(
            'Unsupported save_format: %r' % (self.save_format,))

    logger.log('Restored from checkpoint %s' % checkpoint_file)
# 评论列表 (comment list)
# 文章目录 (table of contents)