def _save_last_and_update_best(params, best_basename):
    """Dump ``params`` to ``params.pkl`` and copy it to ``best_basename`` on improvement.

    Records the latest ``AverageReturn`` / ``Iteration`` tabular values in
    ``_logger_info`` under ``lastReward`` / ``lastItr``. The freshly written
    ``params.pkl`` is copied to ``best_basename`` (inside ``_snapshot_dir``)
    and ``bestReward`` / ``bestItr`` are updated when any of these holds:

    * no best has been recorded yet;
    * the recorded best is NaN;
    * the latest reward is strictly greater than the recorded best.

    Args:
        params: Object to serialize with ``joblib.dump``.
        best_basename: File name (relative to ``_snapshot_dir``) of the best copy.
    """
    last_file_name = osp.join(_snapshot_dir, 'params.pkl')
    joblib.dump(params, last_file_name, compress=3)
    _logger_info["lastReward"] = get_last_tabular("AverageReturn")
    _logger_info["lastItr"] = get_last_tabular("Iteration")
    # NOTE(review): if get_last_tabular returns the stored *string* form of the
    # value, the comparison below is lexicographic rather than numeric — confirm.
    last_reward = _logger_info["lastReward"]
    if "bestReward" in _logger_info:
        best_reward = _logger_info["bestReward"]
        # NaN != NaN. Without this check a NaN best (e.g. from an iteration with
        # no finished episodes) compares False against everything and would
        # never be replaced.
        improved = best_reward != best_reward or best_reward < last_reward
    else:
        improved = True
    if improved:
        shutil.copy(last_file_name, osp.join(_snapshot_dir, best_basename))
        _logger_info["bestReward"] = last_reward
        _logger_info["bestItr"] = _logger_info["lastItr"]


def save_itr_params(itr, params):
    """Persist ``params`` for iteration ``itr`` according to ``_snapshot_mode``.

    Does nothing when ``_snapshot_dir`` is falsy. Supported modes:

    * ``'all'``: write ``itr_<itr>.pkl`` on every call.
    * ``'last'``: overwrite ``params.pkl``.
    * ``'last_best'``: overwrite ``params.pkl`` and copy it to
      ``params_best.pkl`` whenever AverageReturn improves.
    * ``'last_all_best'``: like ``'last_best'``, but each improvement is kept
      as ``params_best_<itr zero-padded to 8 digits>.pkl``.
    * ``'gap'``: write ``itr_<itr>.pkl`` only when ``itr`` is a multiple of
      ``_snapshot_gap``.
    * ``'none'``: do nothing.

    All pickles are written with ``joblib.dump(..., compress=3)``.

    Args:
        itr: Iteration number, used in per-iteration file names.
        params: Object to serialize.

    Raises:
        NotImplementedError: If ``_snapshot_mode`` is not one of the above.
    """
    if not _snapshot_dir:
        return
    if _snapshot_mode == 'all':
        file_name = osp.join(_snapshot_dir, 'itr_%d.pkl' % itr)
        joblib.dump(params, file_name, compress=3)
    elif _snapshot_mode == 'last':
        # Overwrite the previous params.
        file_name = osp.join(_snapshot_dir, 'params.pkl')
        joblib.dump(params, file_name, compress=3)
    elif _snapshot_mode == 'last_best':
        # Keep the last params plus a single best copy.
        _save_last_and_update_best(params, 'params_best.pkl')
    elif _snapshot_mode == 'last_all_best':
        # Keep the last params plus every successive best, tagged by iteration.
        _save_last_and_update_best(params, 'params_best_%08d.pkl' % itr)
    elif _snapshot_mode == "gap":
        if itr % _snapshot_gap == 0:
            file_name = osp.join(_snapshot_dir, 'itr_%d.pkl' % itr)
            joblib.dump(params, file_name, compress=3)
    elif _snapshot_mode == 'none':
        pass
    else:
        raise NotImplementedError(
            "Unknown snapshot mode: %r" % (_snapshot_mode,))
# Stray web-page labels left over from scraping: "Comment list", "Table of contents".