def eval_performance(rom, p_func, n_runs):
    """Run ``n_runs`` evaluation episodes and summarize the episode scores.

    For every run a fresh ``ale.ALE`` environment is created for ``rom``
    with ``treat_life_lost_as_terminal=False`` and stepped until
    ``env.is_terminal`` becomes true.  At each step the current state is
    passed through ``dqn_phi``, given a leading axis of size 1, wrapped in
    a ``chainer.Variable`` and handed to ``p_func``; the first entry of
    the result's ``action_indices`` is sent to the environment.

    Args:
        rom: Passed unchanged as the first argument of ``ale.ALE``.
        p_func: Callable taking the ``chainer.Variable`` built from the
            state and returning an object with an ``action_indices``
            sequence.
        n_runs: Number of episodes to play; must be greater than 1.

    Returns:
        A ``(mean, median, stdev)`` tuple computed over the per-episode
        scores with the ``statistics`` module (``stdev`` is the sample
        standard deviation).

    Raises:
        AssertionError: If ``n_runs`` is not greater than 1.  When
            assertions are disabled (``python -O``) the check is skipped
            and ``statistics`` raises ``StatisticsError`` instead.
    """
    assert n_runs > 1, 'Computing stdev requires at least two runs'
    scores = []
    for i in range(n_runs):
        # A new environment per run so that every episode starts fresh.
        env = ale.ALE(rom, treat_life_lost_as_terminal=False)
        test_r = 0
        while not env.is_terminal:
            # expand_dims(..., 0) adds a leading axis of size 1
            # (presumably the batch axis p_func expects -- TODO confirm).
            s = chainer.Variable(np.expand_dims(dqn_phi(env.state), 0))
            pout = p_func(s)
            # Only one state was fed in, so take the first action index.
            a = pout.action_indices[0]
            # receive_action's return value is accumulated as the episode
            # score (presumably the step reward -- verify against ale.ALE).
            test_r += env.receive_action(a)
        scores.append(test_r)
        print('test_{}:'.format(i), test_r)
    mean = statistics.mean(scores)
    median = statistics.median(scores)
    stdev = statistics.stdev(scores)
    return mean, median, stdev
评论列表
文章目录