def _get_minibatch_feed_dict(self, target_q_values,
non_terminal_minibatch, terminal_minibatch):
"""
Helper to construct the feed_dict for train_op. Takes the non-terminal and
terminal minibatches as well as the max q-values computed from the target
network for non-terminal states. Computes the expected q-values based on
discounted future reward.
@return: feed_dict to be used for train_op
"""
assert len(target_q_values) == len(non_terminal_minibatch)
states = []
expected_q = []
actions = []
# Compute expected q-values to plug into the loss function
minibatch = itertools.chain(non_terminal_minibatch, terminal_minibatch)
for item, target_q in zip_longest(minibatch, target_q_values, fillvalue=0):
state, action, reward, _, _ = item
states.append(state)
# target_q will be 0 for terminal states due to fillvalue in zip_longest
expected_q.append(reward + self.config.reward_discount * target_q)
actions.append(utils.one_hot(action, self.env.action_space.n))
return {
self.network.x_placeholder: states,
self.network.q_placeholder: expected_q,
self.network.action_placeholder: actions,
}