train.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:dqn 作者: prabhatnagarajan 项目源码 文件源码
def validate(ale, agent, no_op_max, hist_len, reward_history, act_rpt):
    """Run one evaluation pass and record the mean reward per finished episode.

    Plays EVAL_STEPS agent decisions in ``ale`` using
    ``agent.eGreedy_action(state, TEST_EPSILON)``. Each decision's action is
    repeated ``act_rpt`` times and the rewards of those repeats are summed.
    Whenever the game ends, the episode reward is banked, the game is reset
    and ``perform_no_ops`` is called again to start the next episode.

    Args:
        ale: Emulator handle; this function calls ``reset_game``, ``act``,
            ``getScreenRGB`` and ``game_over`` on it.
        agent: Object providing ``eGreedy_action(state, epsilon)`` and
            ``update_best_scoring_network()``.
        no_op_max: Forwarded unchanged to ``perform_no_ops``.
        hist_len: Forwarded unchanged to ``get_state``.
        reward_history: List of previous evaluation averages. It is mutated
            in place: this run's average is appended to it.
        act_rpt: Number of times each chosen action is repeated.

    Side effects:
        Calls ``agent.update_best_scoring_network()`` when this run's average
        is strictly greater than every earlier entry of ``reward_history``
        (or when the history is empty).

    Note:
        Only episodes that reach ``game_over`` are counted. Reward from an
        episode still in progress when EVAL_STEPS runs out is discarded. If
        no episode finishes, the recorded average is 0.0.
    """
    ale.reset_game()
    seq = []
    # Bounded to the two most recent raw screens; pp.preprocess below
    # combines exactly these two frames into one processed frame.
    preprocess_stack = deque([], 2)
    perform_no_ops(ale, no_op_max, preprocess_stack, seq)

    total_reward = 0
    num_episodes = 0
    episode_reward = 0
    for _ in range(EVAL_STEPS):
        state = get_state(seq, hist_len)
        action = agent.eGreedy_action(state, TEST_EPSILON)

        # Repeat the chosen action, summing its rewards and capturing a screen
        # after every repeat (only the last two survive in the deque).
        reward = 0
        for _repeat in range(act_rpt):
            reward += ale.act(action)
            preprocess_stack.append(ale.getScreenRGB())
        seq.append(pp.preprocess(preprocess_stack[0], preprocess_stack[1]))
        episode_reward += reward

        if ale.game_over():
            total_reward += episode_reward
            episode_reward = 0
            num_episodes += 1
            ale.reset_game()
            seq = []
            perform_no_ops(ale, no_op_max, preprocess_stack, seq)

    # float() keeps true division even under Python 2; max(1, ...) avoids
    # dividing by zero when no episode finished during the evaluation window.
    avg_reward = float(total_reward) / float(max(1, num_episodes))
    if not reward_history or avg_reward > max(reward_history):
        agent.update_best_scoring_network()
    reward_history.append(avg_reward)

#Returns hist_len most preprocessed frames and memory
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号