def feed_forward_generator(self, advantages, num_mini_batch):
    """Yield randomly sampled mini-batches from the flattened storage.

    The first two dimensions of ``self.rewards`` are read as
    ``(num_steps, num_processes)``. Each stored tensor is flattened over
    those two dimensions, and every mini-batch indexes all of them with the
    same random subset of row indices, so the rows of the yielded tensors
    correspond to each other. ``observations``, ``states``, ``returns`` and
    ``masks`` have their last entry along the first dimension dropped
    before flattening.

    Args:
        advantages: Tensor viewable as ``(num_steps * num_processes, 1)``.
            The sampled indices are created on this tensor's device.
            NOTE(review): the stored tensors are assumed to live on the
            same device as ``advantages`` -- confirm against the caller.
        num_mini_batch: Number of mini-batches to split the data into.
            The mini-batch size is ``batch_size // num_mini_batch``; since
            the sampler keeps the last partial batch, one extra, smaller
            batch is yielded when the division is not exact.

    Yields:
        Tuple ``(observations_batch, states_batch, actions_batch,
        return_batch, masks_batch, old_action_log_probs_batch, adv_targ)``.

    Raises:
        ValueError: If ``num_mini_batch`` exceeds the number of stored
            samples, which would give a mini-batch size below one.
    """
    num_steps, num_processes = self.rewards.size()[0:2]
    batch_size = num_processes * num_steps
    mini_batch_size = batch_size // num_mini_batch
    if mini_batch_size < 1:
        # BatchSampler would reject this too, but without naming the cause.
        raise ValueError(
            f"num_mini_batch ({num_mini_batch}) must be positive and no "
            f"larger than num_steps * num_processes ({batch_size})"
        )
    sampler = BatchSampler(
        SubsetRandomSampler(range(batch_size)),
        mini_batch_size,
        drop_last=False,
    )

    # The flattened views do not depend on the sampled indices, so build
    # them once instead of once per mini-batch.
    observations = self.observations[:-1].view(
        -1, *self.observations.size()[2:])
    states = self.states[:-1].view(-1, self.states.size(-1))
    actions = self.actions.view(-1, self.actions.size(-1))
    returns = self.returns[:-1].view(-1, 1)
    masks = self.masks[:-1].view(-1, 1)
    old_action_log_probs = self.action_log_probs.view(-1, 1)
    flat_advantages = advantages.view(-1, 1)

    for indices in sampler:
        # Create the indices on the device `advantages` lives on; a bare
        # `.cuda()` would target the current default GPU instead.
        indices = torch.as_tensor(
            indices, dtype=torch.long, device=advantages.device)
        observations_batch = observations[indices]
        states_batch = states[indices]
        actions_batch = actions[indices]
        return_batch = returns[indices]
        masks_batch = masks[indices]
        old_action_log_probs_batch = old_action_log_probs[indices]
        adv_targ = flat_advantages[indices]
        yield observations_batch, states_batch, actions_batch, \
            return_batch, masks_batch, old_action_log_probs_batch, adv_targ
评论列表
文章目录