dqn_agent_without_ER.py source code

python

Project: stock_dqn_f  Author: wdy06
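import numpy as np
import chainer.functions as F
from chainer import Variable


def forward(self, state, action, Reward, state_dash, episode_end):
        """Return the clipped TD loss and the Q-values for the given transition(s)."""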

        num_of_batch = state.shape[0]

        Q = self.model.Q_func(state)  # Get Q-value

        # Generate Target Signals
        tmp = self.model_target.Q_func(state_dash)  # Q(s',*)
        # max_a Q(s',a) for each row of the batch
        max_Q_dash = np.max(tmp.data, axis=1).astype(np.float32)
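        # Copy Q.data explicitly: np.asanyarray returns the same array when the
        # dtype already matches, so writing into target would also overwrite Q
        # and the TD error below would always be zero.
        target = np.array(Q.data, dtype=np.float32)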

        for i in range(num_of_batch):
            # Reward, action and episode_end are used as scalars here,
            # so the same values apply to every row of the batch.
            if not episode_end:
                tmp_ = Reward + self.gamma * max_Q_dash[i]
            else:
                tmp_ = Reward
            action_index = self.action_to_index(action)
            target[i, action_index] = tmp_

        # TD-error clipping: keep td where |td| <= 1, otherwise use td/|td| (its sign)
        td = Variable(target) - Q  # TD error
        td_tmp = td.data + 1000.0 * (abs(td.data) <= 1)  # Avoid zero division
        td_clip = td * (abs(td.data) <= 1) + td/abs(td_tmp) * (abs(td.data) > 1)

        # Match the shape of td_clip (one row per sample in this batch)
        zero_val = Variable(np.zeros((num_of_batch, self.num_of_actions), dtype=np.float32))
        loss = F.mean_squared_error(td_clip, zero_val)
        return loss, Q
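
The target construction and TD-error clipping above can also be sketched in plain NumPy, without Chainer. This is only an illustration, not the project's own code: the function names `td_targets` and `clipped_td_loss`, the `dones` array, and the default `gamma` are hypothetical, and unlike the method above it assumes per-sample arrays for actions, rewards and episode-end flags rather than scalars.

```python
import numpy as np


def td_targets(q, q_next, actions, rewards, dones, gamma=0.99):
    """Return a copy of q with the taken action's entry replaced by the TD target."""
    target = np.array(q, dtype=np.float32)  # explicit copy; q is left untouched
    max_q_next = q_next.max(axis=1)  # max_a Q(s', a) for each row
    # No bootstrapping on terminal transitions (dones == 1)
    y = rewards + gamma * max_q_next * (1.0 - dones)
    target[np.arange(len(q)), actions] = y
    return target


def clipped_td_loss(q, target):
    """Mean of squared TD errors after clipping each error to [-1, 1]."""
    td = np.clip(target - q, -1.0, 1.0)
    return float(np.mean(td ** 2))


# Hypothetical two-sample batch with two actions
q = np.array([[1.0, 2.0], [0.5, 0.0]], dtype=np.float32)
q_next = np.array([[0.0, 3.0], [1.0, 1.0]], dtype=np.float32)
actions = np.array([0, 1])
rewards = np.array([1.0, -1.0], dtype=np.float32)
dones = np.array([0.0, 1.0], dtype=np.float32)

target = td_targets(q, q_next, actions, rewards, dones, gamma=0.5)
loss = clipped_td_loss(q, target)
```

Clipping with `np.clip` gives the same values as the `td / abs(td)` construction in the method: errors with magnitude at most 1 pass through unchanged, and larger ones are replaced by their sign. Because only the taken action's entry of `target` differs from `q`, every other entry contributes zero to the loss.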