def save(self, checkpoint_dir=None):
    """Write a checkpoint of ``self.pool``, the TF variables and the progress log.

    Writes up to three files into ``checkpoint_dir``:

    * ``pool.chk`` -- ``self.pool`` serialized according to
      ``self.save_format`` (``'pickle'`` via ``pickle_dump`` or
      ``'joblib'`` via ``joblib.dump`` with ``compress=1``). The data is
      first written to ``pool.chk.tmp`` and only moved over ``pool.chk``
      once the dump succeeded, so a failed dump leaves the previous
      ``pool.chk`` untouched.
    * ``params.chk`` -- variables saved with ``tf.train.Saver`` using the
      session returned by ``tf.get_default_session()``.
    * ``progress.csv.chk`` -- a copy of ``progress.csv`` from
      ``checkpoint_dir``, written only if that file exists.

    Args:
        checkpoint_dir: Directory to write into. When ``None``,
            ``logger.get_snapshot_dir()`` is used.

    Raises:
        NotImplementedError: If ``self.save_format`` is neither
            ``'pickle'`` nor ``'joblib'``. Raised before any file is written.
    """
    if checkpoint_dir is None:
        checkpoint_dir = logger.get_snapshot_dir()

    # Validate the format up front so an unsupported value fails before
    # anything is touched on disk.
    if self.save_format not in ('pickle', 'joblib'):
        raise NotImplementedError(
            'Unsupported save_format: %r' % (self.save_format,))

    pool_file = os.path.join(checkpoint_dir, 'pool.chk')
    tmp_pool_file = pool_file + '.tmp'
    try:
        if self.save_format == 'pickle':
            pickle_dump(tmp_pool_file, self.pool)
        else:
            # joblib.dump's ``cache_size`` argument is deprecated (no effect
            # since joblib 0.10; passing it emits a DeprecationWarning), so
            # it is intentionally not passed.
            joblib.dump(self.pool, tmp_pool_file, compress=1)
    except BaseException:
        # Don't leave a partially written temp file behind; the previous
        # pool.chk (if any) is still intact because the move never ran.
        if os.path.exists(tmp_pool_file):
            os.remove(tmp_pool_file)
        raise
    shutil.move(tmp_pool_file, pool_file)

    checkpoint_file = os.path.join(checkpoint_dir, 'params.chk')
    # NOTE(review): presumably a default session must be active here
    # (e.g. inside ``with sess.as_default():``); Saver.save rejects None.
    sess = tf.get_default_session()
    # NOTE(review): every tf.train.Saver() construction adds new save/restore
    # ops to the default graph, so repeated save() calls grow the graph.
    # Consider building the Saver once, after all variables exist, and
    # reusing it.
    saver = tf.train.Saver()
    saver.save(sess, checkpoint_file)

    # Snapshot the tabular progress log alongside the checkpoint, if present.
    tabular_file = os.path.join(checkpoint_dir, 'progress.csv')
    if os.path.isfile(tabular_file):
        tabular_chk_file = os.path.join(checkpoint_dir, 'progress.csv.chk')
        shutil.copy(tabular_file, tabular_chk_file)

    logger.log('Saved to checkpoint %s' % checkpoint_file)
评论列表
文章目录