def eval_performance(process_idx, make_env, model, phi, n_runs):
    """Play ``n_runs`` evaluation episodes and summarize their total rewards.

    For every run the model's state is reset, a fresh environment is built
    with ``make_env(process_idx, test=True)``, and the episode is played
    until the environment reports ``done``. At each step the action is the
    first entry of ``pout.action_indices`` as returned by
    ``model.pi_and_v``. The summed reward of each episode is printed as
    ``test_<i>: <reward>``.

    Args:
        process_idx: Passed through unchanged to ``make_env``.
        make_env: Callable invoked as ``make_env(process_idx, test=True)``.
            Its result must provide ``reset()`` and ``step(action)``, where
            ``step`` returns a 4-tuple ``(obs, reward, done, info)``.
        model: Object providing ``reset_state()`` and ``pi_and_v(state)``;
            ``pi_and_v`` must return a pair whose first item exposes
            ``action_indices``.
        phi: Callable applied to each observation before it is fed to the
            model.
        n_runs: Number of episodes to play; must be greater than 1.

    Returns:
        A ``(mean, median, stdev)`` tuple over the per-episode total
        rewards, where ``stdev`` is the sample standard deviation.

    Raises:
        AssertionError: If ``n_runs`` is not greater than 1.
    """
    assert n_runs > 1, 'Computing stdev requires at least two runs'
    scores = []
    for i in range(n_runs):
        # Reset the model and build a new test environment for each episode.
        model.reset_state()
        env = make_env(process_idx, test=True)
        obs = env.reset()
        done = False
        test_r = 0
        while not done:
            # Add a leading axis of length 1 (presumably the batch axis --
            # TODO confirm against the model's expected input).
            s = chainer.Variable(np.expand_dims(phi(obs), 0))
            pout, _ = model.pi_and_v(s)
            a = pout.action_indices[0]
            obs, r, done, _ = env.step(a)
            test_r += r
        scores.append(test_r)
        print('test_{}:'.format(i), test_r)
    mean = statistics.mean(scores)
    median = statistics.median(scores)
    stdev = statistics.stdev(scores)
    return mean, median, stdev
# (Web-page navigation residue, not part of the code: "Comment list" / "Table of contents".)