# Module-level imports this method relies on:
#   import chainer
#   import chainerrl
#   from chainerrl.recurrent import state_kept

def update_on_policy(self, statevar):
    assert self.t_start < self.t

    if not self.disable_online_update:
        # Build next-step value estimates: V(s_{t+1}) bootstraps step t.
        next_values = {}
        for t in range(self.t_start + 1, self.t):
            next_values[t - 1] = self.past_values[t]
        if statevar is None:
            # Episode has terminated: bootstrap the final step with zero.
            next_values[self.t - 1] = chainer.Variable(
                self.xp.zeros_like(self.past_values[self.t - 1].data))
        else:
            # Episode continues: evaluate V(s_t) for the current state
            # without disturbing the model's recurrent state.
            with state_kept(self.model):
                _, v = self.model(statevar)
            next_values[self.t - 1] = v
        # Log-probabilities of the actions actually taken, one per step.
        log_probs = {t: self.past_action_distrib[t].log_prob(
            self.xp.asarray(self.xp.expand_dims(a, 0)))
            for t, a in self.past_actions.items()}
        # Accumulate the loss for this rollout fragment.
        self.online_batch_losses.append(self.compute_loss(
            t_start=self.t_start, t_stop=self.t,
            rewards=self.past_rewards,
            values=self.past_values,
            next_values=next_values,
            log_probs=log_probs))
        # Once batchsize losses have been collected, average them and update.
        if len(self.online_batch_losses) == self.batchsize:
            loss = chainerrl.functions.sum_arrays(
                self.online_batch_losses) / self.batchsize
            self.update(loss)
            self.online_batch_losses = []

    # Always clear the per-episode history buffers for the next update.
    self.init_history_data_for_online_update()
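To make the bootstrapping step concrete, here is a minimal, framework-free sketch (an illustration only, not part of chainerrl) of how next_values is just past_values shifted forward by one step, with the terminal entry replaced by zero when statevar is None. The numbers and variable names are invented for the example.

import numpy as np

t_start, t = 0, 4                        # a rollout fragment of 4 steps
past_values = {i: np.float32(v) for i, v in enumerate([0.5, 0.7, 0.9, 1.1])}

next_values = {}
for i in range(t_start + 1, t):
    next_values[i - 1] = past_values[i]  # V(s_{i+1}) bootstraps step i
# terminal state: no successor, so bootstrap with zero (the statevar is None case)
next_values[t - 1] = np.float32(0.0)

print(next_values)  # {0: 0.7, 1: 0.9, 2: 1.1, 3: 0.0}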