def _batch2torch(self, batch, batch_size):
    """Convert a list of transitions into batched pytorch tensors.

    Args:
        batch: list of transition namedtuples ``(state, action, reward,
            state_, done)`` — one per sampled experience.
        batch_size: number of transitions in ``batch``; passed through
            unchanged as the first element of the result.

    Returns:
        ``[batch_size, states, actions, rewards, next_states, mask]``
        where ``states``/``next_states`` are FloatTensors rescaled to
        [0, 1] (shape presumably [batch_size, hist_len, w, h] — set by
        the stored frames, not verifiable here) and ``actions``,
        ``rewards``, ``mask`` each have shape (batch_size, 1).
    """
    # Transpose the batch: (t1, t2, ... tn) ->
    # Transition((s1..sn), (a1..an), ...). See the pytorch DQN tutorial.
    transposed = BatchTransition(*zip(*batch))

    # Frames are stored as uint8 — cast to float and rescale to [0, 1].
    states = torch.cat(transposed.state, 0).type(self.dtype.FT) / 255
    next_states = torch.cat(transposed.state_, 0).type(self.dtype.FT) / 255

    actions = self.dtype.LT(transposed.action).unsqueeze(1)
    rewards = self.dtype.FT(transposed.reward).unsqueeze(1)

    # Non-terminal mask: done=True -> 0, done=False -> 1.
    # NOTE(review): `1 - tensor` is rejected for bool tensors in newer
    # torch; this assumes self.dtype.BT yields a ByteTensor — confirm.
    non_terminal_mask = 1 - self.dtype.BT(transposed.done).unsqueeze(1)

    return [batch_size, states, actions, rewards,
            next_states, non_terminal_mask]
# Source: ntuple_experience_replay.py