# Module-level imports this method relies on:
import copy

import chainer
from chainerrl.misc.batch_states import batch_states


def update(self):
    xp = self.xp

    if self.standardize_advantages:
        # Standardization statistics are computed over the whole memory,
        # so every minibatch is normalized with the same mean and std.
        all_advs = xp.array([b['adv'] for b in self.memory])
        mean_advs = xp.mean(all_advs)
        std_advs = xp.std(all_advs)
    # Freeze a copy of the current network to act as the "old" policy
    # that the PPO probability ratio is computed against.
    target_model = copy.deepcopy(self.model)
    dataset_iter = chainer.iterators.SerialIterator(
        self.memory, self.minibatch_size)

    # Run several optimization epochs over the collected experience.
    dataset_iter.reset()
    while dataset_iter.epoch < self.epochs:
        batch = next(dataset_iter)
        states = batch_states([b['state'] for b in batch], xp, self.phi)
        actions = xp.array([b['action'] for b in batch])
        # Forward pass of the current policy: action distribution and
        # predicted state values for this minibatch.
        distribs, vs_pred = self.model(states)
        # The old policy is only evaluated, never trained, so skip
        # building the computational graph for it.
        with chainer.no_backprop_mode():
            target_distribs, _ = target_model(states)
        advs = xp.array([b['adv'] for b in batch], dtype=xp.float32)
        if self.standardize_advantages:
            # A small epsilon guards against division by a zero std.
            advs = (advs - mean_advs) / (std_advs + 1e-8)
        # One gradient step on the PPO loss, built from the log-probabilities
        # under the new and old policies, the advantages, and value targets.
        self.optimizer.update(
            self._lossfun,
            distribs, vs_pred, distribs.log_prob(actions),
            vs_pred_old=xp.array(
                [b['v_pred'] for b in batch], dtype=xp.float32),
            target_log_probs=target_distribs.log_prob(actions),
            advs=advs,
            vs_teacher=xp.array(
                [b['v_teacher'] for b in batch], dtype=xp.float32),
        )
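
The loss function itself is not shown in this snippet, but the arguments passed to self._lossfun (log-probabilities under the new and old policies, advantages, and value targets) match PPO's clipped surrogate objective. Below is a minimal sketch of such a loss under that assumption; it is not ChainerRL's actual implementation, and the name ppo_loss_sketch as well as the coefficients clip_eps, value_func_coef, and entropy_coef are illustrative.

import chainer.functions as F

def ppo_loss_sketch(distribs, vs_pred, log_probs,
                    vs_pred_old, target_log_probs, advs, vs_teacher,
                    clip_eps=0.2, value_func_coef=1.0, entropy_coef=0.01):
    # Probability ratio r = pi_new(a|s) / pi_old(a|s), computed in
    # log space for numerical stability.
    prob_ratio = F.exp(log_probs - target_log_probs)

    # Clipped surrogate objective: take the pessimistic minimum of the
    # unclipped and clipped terms (Schulman et al., 2017).
    loss_policy = -F.mean(F.minimum(
        prob_ratio * advs,
        F.clip(prob_ratio, 1 - clip_eps, 1 + clip_eps) * advs))

    # Squared-error value loss against the empirical returns; assumes
    # vs_pred and vs_teacher share the same shape. vs_pred_old would only
    # be needed for the clipped-value variant, which this sketch omits.
    loss_value = F.mean_squared_error(vs_pred, vs_teacher)

    # Entropy bonus to discourage premature convergence (assumes a
    # ChainerRL-style distribution exposing an `entropy` property).
    loss_entropy = -F.mean(distribs.entropy)

    return loss_policy + value_func_coef * loss_value + entropy_coef * loss_entropy

Taking the elementwise minimum makes the surrogate a pessimistic bound: once the probability ratio leaves [1 - clip_eps, 1 + clip_eps], the clipped term prevents the update from profiting any further from that sample.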