def __init__(self, env, env_type, hist_len, state_dims, cuda=None):
    """Wrap ``env`` with a frame-preprocessing step and a frame history.

    Args:
        env: environment to wrap. Its ``observation_space.shape`` must have
            at least three entries: the first two are stored as
            ``env_wh`` and the third as ``env_ch`` (presumably an
            (H, W, C) image layout -- confirm against the env).
        env_type: selects the preprocessing method bound to
            ``self._preprocess``; ``"atari"`` uses ``_atari_preprocess``
            and ``"catch"`` uses ``_catch_preprocess``.
        hist_len: number of frames kept in the history ``self.d``.
        state_dims: trailing dimensions of each stored history frame;
            every entry of ``self.d`` has shape ``(1, 1, *state_dims)``.
        cuda: forwarded to ``TorchTypes``; ``None`` is treated as ``False``.

    Raises:
        ValueError: if ``env_type`` is not one of the supported types.
    """
    super(PreprocessFrames, self).__init__(env)
    self.env_type = env_type
    self.state_dims = state_dims
    self.hist_len = hist_len

    obs_shape = self.env.observation_space.shape
    self.env_wh = obs_shape[0:2]
    self.env_ch = obs_shape[2]
    self.wxh = self.env_wh[0] * self.env_wh[1]

    # Map each supported env_type to the *name* of its preprocessing method.
    # Names are resolved with getattr only for the selected type, so (as
    # before) only that one method has to exist on the instance.
    preprocessors = {
        "atari": "_atari_preprocess",
        "catch": "_catch_preprocess",
    }
    if env_type not in preprocessors:
        # Fail at construction instead of leaving self._preprocess unset,
        # which would otherwise surface later as an AttributeError.
        raise ValueError(
            "Unsupported env_type %r; expected one of %s"
            % (env_type, sorted(preprocessors)))
    self._preprocess = getattr(self, preprocessors[env_type])

    print("[Preprocess Wrapper] for %s with state history of %d frames."
          % (self.env_type, hist_len))

    self.cuda = False if cuda is None else cuda
    self.dtype = dtype = TorchTypes(self.cuda)
    # These are the ITU-R BT.709 luma coefficients; presumably used by the
    # preprocessors for RGB -> grayscale conversion (not visible here).
    self.rgb = dtype.FT([.2126, .7152, .0722])

    # Frame history: hist_len zero-initialised tensors of shape
    # (1, 1, *state_dims), keyed 0 .. hist_len - 1. Built from (key, value)
    # pairs so the key order is guaranteed on every Python version; a plain
    # dict comprehension did not guarantee order before Python 3.7.
    self.d = OrderedDict(
        (i, torch.FloatTensor(1, 1, *state_dims).fill_(0))
        for i in range(hist_len))
评论列表
文章目录