wrappers.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:categorical-dqn 作者: floringogianu 项目源码 文件源码
def __init__(self, env, env_type, hist_len, state_dims, cuda=None):
        """Wrap ``env`` and set up the state used for frame preprocessing.

        Args:
            env: Environment to wrap. Its ``observation_space.shape`` is
                indexed here as having at least three entries: the first
                two are stored as ``env_wh`` and the third as ``env_ch``.
            env_type: ``"atari"`` or ``"catch"``; selects which preprocess
                routine is bound to ``self._preprocess``.
            hist_len: Number of per-frame buffers kept in the history.
            state_dims: Iterable of trailing dimensions for each buffer.
            cuda: Flag forwarded to ``TorchTypes``; ``None`` is treated as
                ``False``.
        """
        super(PreprocessFrames, self).__init__(env)

        self.env_type = env_type
        self.state_dims = state_dims
        self.hist_len = hist_len

        # Cache the observation geometry reported by the wrapped env.
        obs_shape = self.env.observation_space.shape
        self.env_wh = obs_shape[0:2]
        self.env_ch = obs_shape[2]
        self.wxh = self.env_wh[0] * self.env_wh[1]

        # need to find a better way
        # NOTE(review): any other env_type leaves `_preprocess` unassigned by
        # this constructor; confirm whether that is intended.
        if self.env_type == "atari":
            self._preprocess = self._atari_preprocess
        elif self.env_type == "catch":
            self._preprocess = self._catch_preprocess
        print("[Preprocess Wrapper] for %s with state history of %d frames."
              % (self.env_type, hist_len))

        self.cuda = False if cuda is None else cuda
        self.dtype = dtype = TorchTypes(self.cuda)
        # presumably per-channel weights for an RGB -> grayscale reduction;
        # verify against the preprocess routines.
        self.rgb = dtype.FT([.2126, .7152, .0722])

        # Earlier variant kept the whole history in one tensor, e.g. of size
        # torch.Size([1, 4, 24, 24]):
        #   self.hist_state = torch.FloatTensor(1, hist_len, *state_dims)
        #   self.hist_state.fill_(0)

        # One zero-filled (1, 1, *state_dims) buffer per history slot, keyed
        # 0 .. hist_len - 1. Built from an ordered stream of pairs rather than
        # from a dict, so insertion order does not depend on the interpreter's
        # dict ordering.
        self.d = OrderedDict(
            (i, torch.FloatTensor(1, 1, *state_dims).fill_(0))
            for i in range(hist_len)
        )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号