ddpg.py source code

python
Project: rllabplusplus  Author: shaneshixiang
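# Excerpt: a method from ddpg.py (TensorFlow 1.x API).
# pickle_dump is a helper defined elsewhere in the project (not shown here).
import os
import shutil

import joblib
import tensorflow as tf
from rllab.misc import logger  # assumed to be rllab's logger module


def save(self, checkpoint_dir=None):
    if checkpoint_dir is None:
        checkpoint_dir = logger.get_snapshot_dir()

    # Write the pool to a temporary file, then move it into place, so an
    # interrupted save does not leave a partially written pool.chk.
    pool_file = os.path.join(checkpoint_dir, 'pool.chk')
    if self.save_format == 'pickle':
        pickle_dump(pool_file + '.tmp', self.pool)
    elif self.save_format == 'joblib':
        # NOTE: cache_size is deprecated in newer joblib releases and may need to be dropped there.
        joblib.dump(self.pool, pool_file + '.tmp', compress=1, cache_size=1e9)
    else:
        raise NotImplementedError
    shutil.move(pool_file + '.tmp', pool_file)

    # Save the TensorFlow variables using the default session.
    checkpoint_file = os.path.join(checkpoint_dir, 'params.chk')
    sess = tf.get_default_session()
    saver = tf.train.Saver()
    saver.save(sess, checkpoint_file)

    # Keep a copy of the progress log alongside this checkpoint.
    tabular_file = os.path.join(checkpoint_dir, 'progress.csv')
    if os.path.isfile(tabular_file):
        tabular_chk_file = os.path.join(checkpoint_dir, 'progress.csv.chk')
        shutil.copy(tabular_file, tabular_chk_file)

    logger.log('Saved to checkpoint %s' % checkpoint_file)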
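
The pool file in save() is written to a '.tmp' path and only then moved over 'pool.chk'. A minimal standalone sketch of that write-then-move pattern follows, using only the standard library. The names atomic_pickle_dump and load_pickle are hypothetical and are not part of rllabplusplus. The sketch also uses os.replace in place of the shutil.move call in the original.

```python
import os
import pickle
import tempfile


def atomic_pickle_dump(obj, path):
    """Pickle obj to path via a temporary file (hypothetical helper)."""
    tmp_path = path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
    # os.replace overwrites an existing destination. Because tmp_path lives
    # in the same directory as path, the rename stays on one filesystem.
    os.replace(tmp_path, path)


def load_pickle(path):
    """Load an object written by atomic_pickle_dump (hypothetical helper)."""
    with open(path, 'rb') as f:
        return pickle.load(f)
```

If the process dies while the temporary file is being written, the previous 'pool.chk' is still intact, which is presumably why the original method takes this detour.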