ddpg.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:rllabplusplus 作者: shaneshixiang 项目源码 文件源码
def restore(self, checkpoint_dir=None):
        """Restore session variables, the tabular log and the pool from a checkpoint.

        A checkpoint is considered present when ``params.chk.meta`` exists in
        ``checkpoint_dir``. If it does not, a message is logged and nothing
        else happens.

        Args:
            checkpoint_dir: Directory holding the checkpoint files. When
                ``None``, ``logger.get_snapshot_dir()`` is used.

        Raises:
            NotImplementedError: If ``self.save_format`` is neither
                ``'pickle'`` nor ``'joblib'``. Note this is raised only after
                the session variables and tabular log were already restored.
        """
        if checkpoint_dir is None:
            checkpoint_dir = logger.get_snapshot_dir()
        checkpoint_file = os.path.join(checkpoint_dir, 'params.chk')

        # Guard clause: without the '.meta' companion file there is nothing
        # to restore from.
        if not os.path.isfile(checkpoint_file + '.meta'):
            logger.log('No checkpoint %s'%checkpoint_file)
            return

        # Load saved variable values into the current default session.
        sess = tf.get_default_session()
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint_file)

        # Roll progress.csv back to the copy stored with the checkpoint. The
        # logger output is detached while the file is overwritten and then
        # re-attached.
        tabular_chk_file = os.path.join(checkpoint_dir, 'progress.csv.chk')
        if os.path.isfile(tabular_chk_file):
            tabular_file = os.path.join(checkpoint_dir, 'progress.csv')
            logger.remove_tabular_output(tabular_file)
            shutil.copy(tabular_chk_file, tabular_file)
            logger.add_tabular_output(tabular_file)

        # NOTE: both loaders deserialize arbitrary objects; only restore from
        # checkpoint directories you trust.
        pool_file = os.path.join(checkpoint_dir, 'pool.chk')
        if self.save_format == 'pickle':
            # The loaded pool was previously discarded here, leaving self.pool
            # unrestored; assign it like the joblib branch does.
            # NOTE(review): assumes pickle_load returns the unpickled object --
            # confirm against its definition.
            self.pool = pickle_load(pool_file)
        elif self.save_format == 'joblib':
            self.pool = joblib.load(pool_file)
        else:
            raise NotImplementedError

        logger.log('Restored from checkpoint %s'%checkpoint_file)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号