dqn_agent.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:dist-dqn 作者: viswanathgs 项目源码 文件源码
def _get_minibatch_feed_dict(self, target_q_values,
                               non_terminal_minibatch, terminal_minibatch):
    """
    Helper to construct the feed_dict for train_op.

    Each non-terminal transition is paired positionally with the max q-value
    computed from the target network for its next state. Its expected q-value
    is the discounted target ``reward + reward_discount * target_q``.
    Terminal transitions have no successor state, so their target q-value is
    taken as 0. In the returned lists, non-terminal items come first, followed
    by terminal items, each group in its original order.

    @param target_q_values: sized sequence with exactly one target q-value
      per element of non_terminal_minibatch.
    @param non_terminal_minibatch: sized sequence of 5-tuples whose first
      three fields are (state, action, reward); the last two are ignored.
    @param terminal_minibatch: iterable of 5-tuples of the same form.
    @return: feed_dict to be used for train_op
    @raise ValueError: if target_q_values and non_terminal_minibatch have
      different lengths.
    """
    # Checked explicitly rather than with `assert`: assertions are stripped
    # under `python -O`. A length mismatch would then silently pair
    # transitions with the wrong target q-values.
    if len(target_q_values) != len(non_terminal_minibatch):
      raise ValueError(
          'Expected one target q-value per non-terminal transition, got %d '
          'target q-values for %d transitions'
          % (len(target_q_values), len(non_terminal_minibatch)))

    discount = self.config.reward_discount
    num_actions = self.env.action_space.n

    # Terminal transitions are paired with a target q-value of 0, so only
    # their immediate reward contributes to the expected q-value.
    transitions = itertools.chain(
        zip(non_terminal_minibatch, target_q_values),
        ((item, 0) for item in terminal_minibatch))

    states = []
    expected_q = []
    actions = []
    for (state, action, reward, _, _), target_q in transitions:
      states.append(state)
      expected_q.append(reward + discount * target_q)
      actions.append(utils.one_hot(action, num_actions))

    return {
      self.network.x_placeholder: states,
      self.network.q_placeholder: expected_q,
      self.network.action_placeholder: actions,
    }
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号