import random
import numpy as np

# MINI_BATCH_SIZE and DISCOUNT_FACTOR are hyperparameters defined elsewhere in
# atari-qlearning.py.
def train(sess, q_network, target_network, observations):
    # Sample a minibatch of stored transitions to train on.
    mini_batch = random.sample(observations, MINI_BATCH_SIZE)
    states = [d['state'] for d in mini_batch]
    actions = [d['action'] for d in mini_batch]
    rewards = [d['reward'] for d in mini_batch]
    next_states = [d['next_state'] for d in mini_batch]
    terminal = np.array([d['terminal'] for d in mini_batch])

    # Compute Q(s', a'; theta'), where theta' are the target-network
    # parameters. The reward plus the discounted max over actions gives the
    # target y_i = r + gamma * max_a' Q(s', a'; theta') from eqn 2 in the
    # DQN paper.
    next_q = sess.run(target_network.output_layer, feed_dict={
        target_network.input_layer: next_states
    })

    # Terminal transitions contribute only the immediate reward, so the
    # bootstrap term is masked out with the inverted terminal flags.
    target_q = rewards + (
        np.invert(terminal).astype('float32')
        * DISCOUNT_FACTOR
        * np.max(next_q, axis=1)
    )
    one_hot_actions = compute_one_hot_actions(actions)

    # Train the q-network (i.e. update the online parameters theta) towards
    # these targets.
    q_network.train(sess, states, one_hot_actions, target_q)
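
# Sketch only, not the file's actual DeepQNetwork implementation: the kind of
# TF 1.x loss a q_network.train step like the one above typically minimizes.
# The one-hot mask picks out Q(s, a; theta) for the action that was taken and
# regresses it onto target_q. num_actions and the RMSProp learning rate below
# are illustrative assumptions.
def build_q_train_op(output_layer, num_actions, learning_rate=0.00025):
    import tensorflow as tf  # TF 1.x graph-mode API
    one_hot_actions_ph = tf.placeholder(tf.float32, [None, num_actions])
    target_q_ph = tf.placeholder(tf.float32, [None])
    # Q-value of the chosen action in each sampled transition.
    chosen_q = tf.reduce_sum(output_layer * one_hot_actions_ph, axis=1)
    loss = tf.reduce_mean(tf.square(target_q_ph - chosen_q))
    train_op = tf.train.RMSPropOptimizer(learning_rate).minimize(loss)
    return one_hot_actions_ph, target_q_ph, train_op
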
# Return a one hot vector with a 1 at the index for the action.
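# A minimal sketch of compute_one_hot_actions, whose body is not shown in this
# excerpt; it assumes a module-level NUM_ACTIONS constant giving the size of
# the action space.
def compute_one_hot_actions(actions):
    # One row per sampled action, with a 1.0 at that action's index.
    one_hot = np.zeros((len(actions), NUM_ACTIONS), dtype='float32')
    one_hot[np.arange(len(actions)), actions] = 1.0
    return one_hot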