multistate_dqn.py source code

Language: Python

Project: chainer_frmqn    Author: okdshin
def calc_loss(self, state, state_dash, actions, rewards, done_list):
    # Assumes the module-level imports of the original file:
    # numpy as np, functools.reduce, and chainer's Variable and functions as F.
    assert(state.shape == state_dash.shape)
    # Flatten every dimension after the batch axis so the states can be fed to the Q-network.
    s = state.reshape((state.shape[0], reduce(lambda x, y: x*y, state.shape[1:]))).astype(np.float32)
    s_dash = state_dash.reshape((state.shape[0], reduce(lambda x, y: x*y, state.shape[1:]))).astype(np.float32)
    q = self.model.q_function(s)  # Q(s,*) from the online network

    q_dash = self.model_target.q_function(s_dash)  # Q(s',*) from the target network
    max_q_dash = np.asarray(list(map(np.max, q_dash.data)), dtype=np.float32)  # max_a Q(s',a)

    # Build the TD target: copy the current Q-values and overwrite only the
    # entry of the action that was actually taken in each transition.
    target = q.data.copy()
    for i in range(self.replay_batch_size):
        assert(self.replay_batch_size == len(done_list))
        r = np.sign(rewards[i]) if self.clipping else rewards[i]  # optional reward clipping
        if done_list[i]:
            discounted_sum = r  # terminal transition: no bootstrap term
        else:
            discounted_sum = r + self.gamma * max_q_dash[i]
        assert(self.replay_batch_size == len(actions))
        target[i, actions[i]] = discounted_sum

    # Huber loss between the (constant) target and the online Q-values.
    loss = F.sum(F.huber_loss(Variable(target), q, delta=1.0))  # / self.replay_batch_size
    return loss, q
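
The loop above implements the standard DQN target, y_i = r_i for terminal transitions and y_i = r_i + gamma * max_a Q_target(s'_i, a) otherwise, written into a copy of the online Q-values so that only the taken action's entry contributes to the Huber loss. Below is a minimal usage sketch for how this method might be called with a replay-sampled batch; the array shapes, the number of actions, and the `agent` object are assumptions for illustration and are not part of the original file.

# Hypothetical usage sketch (shapes and agent construction are assumed, not from the source)
import numpy as np

batch_size = 32           # assumed to equal agent.replay_batch_size
state_shape = (4, 16)     # assumed multi-frame state shape per transition

# A replay-buffer sample: (s, s', a, r, done) for each transition in the batch
state      = np.random.rand(batch_size, *state_shape).astype(np.float32)
state_dash = np.random.rand(batch_size, *state_shape).astype(np.float32)
actions    = np.random.randint(0, 4, size=batch_size)          # 4 assumed discrete actions
rewards    = np.random.rand(batch_size).astype(np.float32)
done_list  = np.zeros(batch_size, dtype=bool)

# `agent` would be an already-constructed instance of the class this method belongs to;
# loss is a chainer Variable to backpropagate, q holds Q(s, ·) for inspection.
# loss, q = agent.calc_loss(state, state_dash, actions, rewards, done_list)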