def _encode_observation(self, idx):
    """Return the observation at ``idx`` stacked with its preceding frames.

    Parameters
    ----------
    idx : int
        Position in ``self.obs`` of the most recent frame to encode.

    Returns
    -------
    numpy.ndarray
        * If ``self.obs`` is 2-D (low-dimensional observations such as RAM
          state), the single row ``self.obs[idx]``.
        * Otherwise the ``self.frame_history_len`` frames ending at ``idx``
          joined along axis 2, i.e. shape ``(img_h, img_w, -1)``.  Frames
          that are unavailable are replaced by leading all-zero frames.

    Notes
    -----
    Positions are taken modulo ``self.size``, so the context window may wrap
    around the start of the buffer once it has been completely filled.
    A set entry in ``self.done`` inside the window (presumably an episode
    end -- confirm against the writer of ``self.done``) truncates the
    context so that it starts right after that entry.
    """
    end_idx = idx + 1  # exclusive upper bound of the context window
    start_idx = end_idx - self.frame_history_len

    # Low-dimensional observations (e.g. RAM state) carry no frame history:
    # hand back the latest one directly.
    if len(self.obs.shape) == 2:
        return self.obs[end_idx - 1]

    # The buffer has never been filled, so there is nothing valid to wrap
    # around to before position 0.
    if start_idx < 0 and self.num_in_buffer != self.size:
        start_idx = 0

    # Move the window start past the last ``done`` flag found before the
    # final frame.  A dedicated loop variable is used so the ``idx``
    # argument is not overwritten.
    for pos in range(start_idx, end_idx - 1):
        if self.done[pos % self.size]:
            start_idx = pos + 1

    missing_context = self.frame_history_len - (end_idx - start_idx)

    # Slow path: zero padding is needed for missing context, or the window
    # crosses the boundary of the buffer and must be gathered frame by frame.
    if start_idx < 0 or missing_context > 0:
        frames = [np.zeros_like(self.obs[0]) for _ in range(missing_context)]
        frames.extend(
            self.obs[pos % self.size] for pos in range(start_idx, end_idx)
        )
        return np.concatenate(frames, 2)

    # Fast path: the window is one contiguous slice, so a transpose + reshape
    # gives the same layout as the concatenation above without building a
    # Python list (the original author estimated ~30% compute-time savings).
    img_h, img_w = self.obs.shape[1], self.obs.shape[2]
    window = self.obs[start_idx:end_idx]
    return window.transpose(1, 2, 0, 3).reshape(img_h, img_w, -1)
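# Comment list (web-page navigation label left over from the source page; not code)
# Article table of contents (web-page navigation label left over from the source page; not code)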