def eval_performance(rom, p_func, n_runs):
    """Evaluate a policy on an ALE ROM over several episodes.

    Each run creates a new ``ale.ALE`` environment for ``rom`` (passing
    ``treat_life_lost_as_terminal=False``), then repeatedly:

    1. preprocesses ``env.state`` with ``util.dqn_phi``;
    2. passes it to ``p_func``;
    3. draws an action with ``util.categorical_sample``;
    4. feeds that action to ``env.receive_action``.

    This repeats until ``env.is_terminal`` is true. The values returned by
    ``receive_action`` are summed into that run's score. Each score is
    printed as its run finishes.

    Args:
        rom: ROM argument passed straight through to ``ale.ALE``.
        p_func: Callable applied to the ``util.dqn_phi`` output. Its result
            is handed to ``util.categorical_sample`` (presumably a vector
            of action probabilities; confirm against callers).
        n_runs: Number of episodes to play. Must be greater than 1, because
            the sample standard deviation needs at least two data points.

    Returns:
        A ``(mean, median, stdev)`` tuple over the per-run scores, computed
        with the ``statistics`` module. ``stdev`` is the sample standard
        deviation.

    Raises:
        ValueError: if ``n_runs`` is not greater than 1.
    """
    # Validate explicitly instead of with ``assert``. Asserts are stripped
    # under ``python -O``. The bad value would then only surface inside the
    # statistics calls, after every episode had already been played.
    if n_runs <= 1:
        raise ValueError('Computing stdev requires at least two runs')

    scores = []
    for i in range(n_runs):
        # A new environment per run, so every run starts a fresh episode.
        env = ale.ALE(rom, treat_life_lost_as_terminal=False)
        test_r = 0
        while not env.is_terminal:
            s = util.dqn_phi(env.state)
            pout = p_func(s)
            a = util.categorical_sample(pout)
            test_r += env.receive_action(a)
        scores.append(test_r)
        # One pre-formatted argument prints the same text under the
        # Python 2 print statement and the Python 3 print() function.
        print('test_ %s : %s' % (i, test_r))

    mean = statistics.mean(scores)
    median = statistics.median(scores)
    stdev = statistics.stdev(scores)
    return mean, median, stdev
评论列表
文章目录