logger.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:rllabplusplus 作者: shaneshixiang 项目源码 文件源码
def save_itr_params(itr, params):
    """Snapshot *params* for iteration *itr* according to ``_snapshot_mode``.

    Nothing is written when ``_snapshot_dir`` is not set (falsy). Supported modes:

      'all'           -- write ``itr_<itr>.pkl`` on every call.
      'last'          -- overwrite ``params.pkl`` on every call.
      'last_best'     -- overwrite ``params.pkl``; when the latest "AverageReturn"
                         beats the best seen so far, copy it to ``params_best.pkl``.
      'last_all_best' -- like 'last_best', but each new best is kept as
                         ``params_best_<itr:08d>.pkl``.
      'gap'           -- write ``itr_<itr>.pkl`` when ``itr`` is a multiple of
                         ``_snapshot_gap``.
      'none'          -- save nothing.

    Args:
        itr: Iteration index, used in snapshot file names and for 'gap' mode.
        params: Object serialized with ``joblib.dump``.

    Raises:
        NotImplementedError: if ``_snapshot_mode`` is not one of the modes above.
    """
    if not _snapshot_dir:
        return

    if _snapshot_mode == 'all':
        file_name = osp.join(_snapshot_dir, 'itr_%d.pkl' % itr)
        joblib.dump(params, file_name, compress=3)
    elif _snapshot_mode == 'last':
        # Overwrite the previous snapshot.
        file_name = osp.join(_snapshot_dir, 'params.pkl')
        joblib.dump(params, file_name, compress=3)
    elif _snapshot_mode == 'last_best':
        # Keep the last snapshot plus a single best snapshot.
        _dump_last_and_track_best(params, osp.join(_snapshot_dir, 'params_best.pkl'))
    elif _snapshot_mode == 'last_all_best':
        # Keep the last snapshot plus every snapshot that set a new best.
        _dump_last_and_track_best(
            params, osp.join(_snapshot_dir, 'params_best_%08d.pkl' % itr))
    elif _snapshot_mode == 'gap':
        if itr % _snapshot_gap == 0:
            file_name = osp.join(_snapshot_dir, 'itr_%d.pkl' % itr)
            joblib.dump(params, file_name, compress=3)
    elif _snapshot_mode == 'none':
        pass
    else:
        raise NotImplementedError('Unknown snapshot mode: %r' % (_snapshot_mode,))


def _dump_last_and_track_best(params, best_file_name):
    """Write ``params.pkl`` and copy it to *best_file_name* on a new best reward.

    Shared by the 'last_best' and 'last_all_best' snapshot modes. Stores the
    latest "AverageReturn" / "Iteration" tabular values in ``_logger_info`` as
    ``lastReward`` / ``lastItr``. If no ``bestReward`` is recorded yet, or the
    new reward is strictly greater, the freshly written file is copied to
    *best_file_name* and ``bestReward`` / ``bestItr`` are updated.
    """
    last_file_name = osp.join(_snapshot_dir, 'params.pkl')
    joblib.dump(params, last_file_name, compress=3)
    _logger_info["lastReward"] = get_last_tabular("AverageReturn")
    _logger_info["lastItr"] = get_last_tabular("Iteration")
    # NOTE(review): plain `<` comparison. If get_last_tabular returns the
    # recorded values as strings, this compares lexicographically rather than
    # numerically -- confirm the return type of get_last_tabular.
    if ("bestReward" not in _logger_info
            or _logger_info["bestReward"] < _logger_info["lastReward"]):
        # Copy the file just written instead of serializing params a second time.
        shutil.copy(last_file_name, best_file_name)
        _logger_info["bestReward"] = _logger_info["lastReward"]
        _logger_info["bestItr"] = _logger_info["lastItr"]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号