test_envs_semantics.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目: AI-Fight-the-Landlord 作者: YoungGer 项目源码 文件源码
def test_env_semantics(spec):
  """Check a freshly generated rollout of ``spec`` against the recorded one.

  The recorded rollouts are loaded as JSON from ``ROLLOUT_FILE`` and looked up
  by ``spec.id``. When no entry exists for that id, the function returns
  without comparing anything; it logs a warning first unless
  ``spec.nondeterministic`` is truthy.

  Otherwise ``generate_rollout_hash(spec)`` is called and its four results
  (observations, actions, rewards, dones -- presumably hash digests, given the
  helper's name; confirm against its definition) are compared with the
  recorded entry's ``'observations'``, ``'actions'``, ``'rewards'`` and
  ``'dones'`` values.

  Raises:
    ValueError: carrying the list of mismatch messages, if any recorded value
      differs from the freshly generated one. Each message is also logged.
  """
  with open(ROLLOUT_FILE) as handle:
    recorded_rollouts = json.load(handle)

  env_id = spec.id
  if env_id not in recorded_rollouts:
    # Only deterministic envs are expected to have a recorded rollout, so a
    # missing entry is worth a warning only in that case.
    if not spec.nondeterministic:
      # NOTE(review): on a stdlib logging.Logger, ``warn`` is a deprecated
      # alias of ``warning``; left unchanged because ``logger`` is defined
      # outside this block and may not provide ``warning``.
      logger.warn("Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(env_id))
    return

  logger.info("Testing rollout for {} environment...".format(env_id))

  expected = recorded_rollouts[env_id]
  observations_now, actions_now, rewards_now, dones_now = generate_rollout_hash(spec)

  # (recorded key, freshly generated value, message template), in the order
  # the comparisons are reported.
  checks = (
    ('observations', observations_now, 'Observations not equal for {} -- expected {} but got {}'),
    ('actions', actions_now, 'Actions not equal for {} -- expected {} but got {}'),
    ('rewards', rewards_now, 'Rewards not equal for {} -- expected {} but got {}'),
    ('dones', dones_now, 'Dones not equal for {} -- expected {} but got {}'),
  )
  mismatches = [
    template.format(env_id, expected[key], current)
    for key, current, template in checks
    if expected[key] != current
  ]

  if mismatches:
    for message in mismatches:
      logger.warn(message)
    raise ValueError(mismatches)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号