def evaluate_on_diff_env(env, n_sample_traj, agent):
    """Roll out ``agent`` in the GAN (learned) version of the environment.

    Samples ``n_sample_traj`` trajectories as one batch for
    ``env.spec.timestep_limit`` steps and accumulates a normalized,
    negated sum of the masked rewards.

    Args:
        env: Learned environment model. Calling it with a batch of actions
            returns a batch whose second-to-last column is the reward, whose
            last column is the episode-end flag, and whose remaining leading
            columns are the next observation. It must expose
            ``reset_state()``, ``act_size`` and ``spec.timestep_limit``.
        n_sample_traj: Number of trajectories (batch rows) to sample.
        agent: Callable mapping a batch of observations to a batch of
            actions; must expose ``reset_state()``.

    Returns:
        The sum over all steps of
        ``F.sum(reward * (1 - end)) / (-batch_size * timestep_limit)``,
        i.e. the negated mean per-step reward, with rewards zeroed where the
        end flag is set. NOTE(review): the sign flip suggests this is used
        as a loss to minimize; confirm with callers.
    """
    total = 0.0

    # Clear any internal state held by the environment model and the agent.
    env.reset_state()
    agent.reset_state()

    # Step the model once with an all-zero action batch to obtain the
    # initial observations; that step's reward/end columns are discarded.
    zero_actions = tv(np.zeros((n_sample_traj, env.act_size)))
    obs = env(zero_actions)[:, :-2]

    horizon = env.spec.timestep_limit
    for _ in range(horizon):
        step_out = env(agent(obs))
        reward = step_out[:, -2]
        done = step_out[:, -1]
        obs = step_out[:, :-2]

        # Mask out rewards where the end flag is set, then normalize by
        # batch size and horizon (negated) at every step.
        batch = len(reward)
        total += F.sum(reward * (1.0 - done)) / (-batch * horizon)

    return total
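# Scraped web-page residue, not code: "评论列表" ("Comment list")
# Scraped web-page residue, not code: "文章目录" ("Table of contents")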