multistate_dqn.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:chainer_frmqn 作者: okdshin 项目源码 文件源码
def calc_loss_recurrent(self, frames, actions, rewards, done_list, size_list):
        """Compute the summed, length-masked Huber TD loss over a frame sequence.

        For every step ``t`` in ``range(self.max_step)`` the online model is
        evaluated on ``frames[t]`` and the target model on ``frames[t + 1]``.
        The one-step target is ``r_t + (not done_t) * gamma * max_a Q_target``,
        with ``r_t`` replaced by its sign when ``self.clipping`` is truthy.

        Args:
            frames: Array indexed by step along its first axis; it is cast to
                float32. Indices ``0 .. self.max_step`` are read, so it needs
                at least ``self.max_step + 1`` entries along that axis.
            actions: Per-step action indices, indexed as ``actions[t]``.
            rewards: Per-step rewards, indexed as ``rewards[t]``.
            done_list: Per-step termination flags, indexed as ``done_list[t]``.
            size_list: Per-sample count of valid steps; steps at or beyond
                this count contribute nothing to the loss. Assumed to hold
                ``self.replay_batch_size`` entries -- TODO confirm with caller.

        Returns:
            A scalar Chainer variable holding the sum of the element-wise
            Huber loss over all enabled (step, sample) pairs. The sum is not
            divided by the batch size.
        """
        # TODO self.max_step -> max_step
        xp = self.xp
        s = Variable(frames.astype(np.float32))

        self.model_target.reset_state()  # Refresh model_target's state
        self.model_target.q_function(s[0])  # Update target model initial state

        target_q = xp.zeros(
            (self.max_step, self.replay_batch_size), dtype=np.float32)
        selected_q_list = []

        for frame in range(self.max_step):
            q = self.model.q_function(s[frame])
            # Q(s',*): shape is (batch_size, action_num)
            q_dash = self.model_target.q_function(s[frame + 1])
            # max_a Q(s',a): shape is (batch_size,)
            max_q_dash = q_dash.data.max(axis=1)

            rs = xp.sign(rewards[frame]) if self.clipping else rewards[frame]

            # ``np.int`` was removed from NumPy, so explicit dtypes are used.
            # The 0/1 mask drops the bootstrap term on terminal steps.
            not_done = xp.logical_not(done_list[frame]).astype(np.float32)
            target_q[frame] = rs + not_done * (self.gamma * max_q_dash)
            selected_q_list.append(
                F.select_item(q, actions[frame].astype(np.int32)))

        # enable[t, b] is True while step t is inside sample b's valid length.
        steps = xp.broadcast_to(
            xp.arange(self.max_step), (self.replay_batch_size, self.max_step))
        # Use the active array module instead of an unconditional GPU transfer
        # so the mask lives on the same device as ``steps``.
        valid_len = xp.expand_dims(xp.asarray(size_list), -1)
        enable = (steps < valid_len).T

        # Concatenation is step-major, matching the flattened target_q.
        selected_q = F.concat(selected_q_list, axis=0)

        # element-wise huber loss
        huber_loss = F.huber_loss(
                F.expand_dims(F.flatten(target_q), axis=1),
                F.expand_dims(selected_q, axis=1), delta=1.0)
        huber_loss = F.reshape(huber_loss, enable.shape)

        # Zero out padded steps before summing (no batch-size normalisation).
        zeros = xp.zeros(enable.shape, dtype=np.float32)
        loss = F.sum(F.where(enable, huber_loss, zeros))

        return loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号