atari-qlearning.py source code

Language: python

Project: reinforcement-learning | Author: cgnicholls
import random

import numpy as np

# MINI_BATCH_SIZE and DISCOUNT_FACTOR are module-level hyperparameters
# defined elsewhere in atari-qlearning.py.

def train(sess, q_network, target_network, observations):
    # Sample a minibatch of transitions from the replay memory to train on.
    mini_batch = random.sample(observations, MINI_BATCH_SIZE)

    states = [d['state'] for d in mini_batch]
    actions = [d['action'] for d in mini_batch]
    rewards = [d['reward'] for d in mini_batch]
    next_states = [d['next_state'] for d in mini_batch]
    terminal = np.array([d['terminal'] for d in mini_batch])

    # Compute Q(s', a'; theta'), where theta' are the parameters of the target
    # network. These values are used to form the targets y_i as in eqn. 2 of
    # the DQN paper.
    next_q = sess.run(target_network.output_layer, feed_dict={
        target_network.input_layer: next_states
    })

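    # Bellman backup target: y_i = r_i for terminal transitions, otherwise
    # y_i = r_i + DISCOUNT_FACTOR * max_a' Q(s', a'; theta'). np.invert flips
    # the boolean terminal mask, zeroing the bootstrap term at episode ends.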
    target_q = rewards + (np.invert(terminal).astype('float32')
                          * DISCOUNT_FACTOR * np.max(next_q, axis=1))

    one_hot_actions = compute_one_hot_actions(actions)

    # Train the q-network (i.e. the parameters theta).
    q_network.train(sess, states, one_hot_actions, target_q)
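    # (Not shown in this excerpt: q_network.train is expected to run one
    # optimiser step minimising the squared error between Q(s, a; theta)
    # for the chosen actions and target_q.)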

# Return a one hot vector with a 1 at the index for the action.
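The excerpt cuts off at the header comment above: the compute_one_hot_actions helper it documents is called in train but not shown on this page. Below is a minimal sketch of such a helper, assuming a module-level NUM_ACTIONS constant for the size of the action space (an assumption for illustration; the actual definition lives elsewhere in the file):

import numpy as np

NUM_ACTIONS = 4  # assumed action-space size, not taken from the excerpt

def compute_one_hot_actions(actions):
    # One row per sampled action, with a 1 at that action's index.
    one_hot = np.zeros((len(actions), NUM_ACTIONS), dtype='float32')
    for i, action in enumerate(actions):
        one_hot[i][action] = 1.0
    return one_hot

As the dictionary lookups in train show, each entry of observations is a transition dict with keys 'state', 'action', 'reward', 'next_state' and 'terminal', appended to the replay memory as the agent plays.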