ntuple_experience_replay.py source code

python

Project: categorical-dqn  Author: floringogianu
    def _batch2torch(self, batch, batch_size):
        """Convert a list of transitions into a batch of PyTorch tensors.

            Returns:
                states: torch.Size([batch_size, hist_len, w, h])
                a/r/d: torch.Size([batch_size, 1])
        """
        # See the PyTorch DQN tutorial for this transpose idiom.
        # (t1, t2, ... tn) -> t((s1, s2, ..., sn), (a1, a2, ... an) ...)
        batch = BatchTransition(*zip(*batch))

        # lists to tensors
        state_batch = torch.cat(batch.state, 0).type(self.dtype.FT) / 255
        action_batch = self.dtype.LT(batch.action).unsqueeze(1)
        reward_batch = self.dtype.FT(batch.reward).unsqueeze(1)
        next_state_batch = torch.cat(batch.state_, 0).type(self.dtype.FT) / 255
        # [False, False, True, False] -> [1, 1, 0, 1]::ByteTensor
        mask = 1 - self.dtype.BT(batch.done).unsqueeze(1)

        return [batch_size, state_batch, action_batch, reward_batch,
                next_state_batch, mask]
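
The transpose step and the done-flag inversion in the method above can be tried without PyTorch. The following is a minimal sketch, not the project's code: the definition of BatchTransition is not shown in this excerpt, so its field names here are inferred from the attributes the method accesses, and the helper names transpose_batch and not_done_mask are hypothetical.

```python
from collections import namedtuple

# Assumed definition: the original BatchTransition is not shown in the excerpt;
# these field names are inferred from the attributes used in _batch2torch.
BatchTransition = namedtuple(
    "BatchTransition", ["state", "action", "reward", "state_", "done"])


def transpose_batch(batch):
    """Turn a list of per-step transitions into one tuple of per-field tuples."""
    # zip(*batch) groups the i-th element of every transition together.
    return BatchTransition(*zip(*batch))


def not_done_mask(done_flags):
    """Map done flags to a 0/1 mask: 1 for non-terminal, 0 for terminal."""
    return [1 - int(d) for d in done_flags]


# Toy transitions; strings stand in for the image tensors of the original.
transitions = [
    ("s0", 1, 0.0, "s1", False),
    ("s1", 0, 1.0, "s2", True),
    ("s2", 2, -1.0, "s3", False),
]
fields = transpose_batch(transitions)
mask = not_done_mask(fields.done)
```

Here fields.action collects every action in order and mask marks the second transition as terminal. In DQN-style updates, a mask of this kind is commonly multiplied into the bootstrapped next-state value so that terminal transitions contribute only their reward.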