gan_rl_fitter.py source code

python

Project: gan-rl  Author: iaroslav-ai
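import numpy as np
import chainer.functions as F  # assumption: the F.sum call suggests Chainer's functional API

# NOTE: `tv` is a helper defined elsewhere in gan_rl_fitter.py; it is assumed
# here to wrap a NumPy array as a framework variable.


def evaluate_on_diff_env(env, n_sample_traj, agent):
    # Sample n_sample_traj trajectories using the GAN version of the
    # environment and return the negated mean reward per step.
    R = 0.0

    # reset the internal state of the environment model and the agent
    env.reset_state()
    agent.reset_state()

    # get the initial observation by feeding zero actions; the last two
    # output columns are (reward, end-of-episode flag)
    observations = env(tv(np.zeros((n_sample_traj, env.act_size))))[:, :-2]

    for _ in range(env.spec.timestep_limit):
        act = agent(observations)
        obs_rew = env(act)
        rewards = obs_rew[:, -2]
        ends = obs_rew[:, -1]
        observations = obs_rew[:, :-2]
        # mask rewards where the end flag is set; the negative divisor makes
        # R a loss (a lower R means a higher average reward)
        R += F.sum(rewards * (1.0 - ends)) / (-len(rewards) * env.spec.timestep_limit)

    return R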
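
The column layout and the normalization used by the function above can be checked in isolation. The following is a minimal sketch in plain NumPy rather than the original framework: `ToyEnv`, `negated_mean_reward`, and the constant-zero agent are hypothetical stand-ins, not part of gan-rl. It only illustrates that each environment output row is read as `[observation..., reward, end flag]` and that the masked rewards are averaged over trajectories and time steps with a negative sign.

```python
import numpy as np


class ToyEnv:
    """Hypothetical batched environment: each call returns rows of
    [observation..., reward, end flag], mirroring the slicing above."""

    def __init__(self, obs_size=3, act_size=2, timestep_limit=4):
        self.obs_size = obs_size
        self.act_size = act_size
        self.timestep_limit = timestep_limit
        self.t = 0

    def reset_state(self):
        self.t = 0

    def __call__(self, actions):
        self.t += 1
        n = actions.shape[0]
        obs = np.full((n, self.obs_size), float(self.t))
        reward = np.ones((n, 1))  # constant reward of 1 per step
        # end flag becomes 1 once more than timestep_limit calls were made
        end = np.full((n, 1), float(self.t > self.timestep_limit))
        return np.hstack([obs, reward, end])


def negated_mean_reward(env, n_sample_traj, agent):
    """NumPy re-statement of the accumulation loop (assumed equivalent)."""
    env.reset_state()
    # initial observation from zero actions; drop the last two columns
    observations = env(np.zeros((n_sample_traj, env.act_size)))[:, :-2]
    total = 0.0
    for _ in range(env.timestep_limit):
        out = env(agent(observations))
        rewards, ends = out[:, -2], out[:, -1]
        observations = out[:, :-2]
        # rewards with the end flag set are masked out; the divisor is negative
        total += np.sum(rewards * (1.0 - ends)) / (-len(rewards) * env.timestep_limit)
    return total


env = ToyEnv()
agent = lambda obs: np.zeros((obs.shape[0], env.act_size))
result = negated_mean_reward(env, 5, agent)
print(result)  # -0.75: three of the four loop steps are unmasked
```

In this toy setup the initial zero-action call also advances the step counter, so the end flag is set on the last loop iteration and only three of the four steps contribute.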