def compute_actor_loss(self, batch):
"""Compute loss for actor.
Preconditions:
q_function must have seen up to s_{t-1} and s_{t-1}.
policy must have seen up to s_{t-1}.
Preconditions:
q_function must have seen up to s_t and s_t.
policy must have seen up to s_t.
"""
    batch_state = batch['state']
    batch_action = batch['action']
    batch_size = len(batch_action)
    # Estimated policy observes s_t
    onpolicy_actions = self.policy(batch_state).sample()
    # Q(s_t, mu(s_t)) is evaluated.
    # This should not affect the internal state of Q.
    with state_kept(self.q_function):
        q = self.q_function(batch_state, onpolicy_actions)
    # Estimated Q-function observes s_t and a_t
    if isinstance(self.q_function, Recurrent):
        self.q_function.update_state(batch_state, batch_action)
    # Avoid the numpy #9165 bug (see also: chainer #2744)
    q = q[:, :]
    # Since we want to maximize Q, loss is negation of Q
    loss = - F.sum(q) / batch_size
    # Update stats
    self.average_actor_loss *= self.average_loss_decay
    self.average_actor_loss += ((1 - self.average_loss_decay) *
                                float(loss.data))
    return loss
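The returned value is a minibatch estimate of the negated DDPG actor objective: with $\mu$ the current policy and $N$ the batch size,

$$L = -\frac{1}{N}\sum_{i=1}^{N} Q\bigl(s_i, \mu(s_i)\bigr),$$

so minimizing this loss with respect to the policy parameters pushes the policy toward actions with higher estimated Q-values.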
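Because the Q-function may be recurrent, evaluating Q(s_t, mu(s_t)) above must not advance its hidden state; only the real transition (s_t, a_t) should, via update_state. Below is a minimal sketch of the save-and-restore behavior a context manager like state_kept provides, assuming the wrapped model exposes get_state / set_state as ChainerRL's Recurrent interface does; it is an illustration, not the library's exact implementation.

from contextlib import contextmanager

@contextmanager
def state_kept_sketch(recurrent_model):
    # Snapshot the hidden state before the temporary forward pass.
    saved_state = recurrent_model.get_state()
    try:
        yield
    finally:
        # Restore it afterwards, so evaluating Q(s_t, mu(s_t)) leaves the
        # recurrent Q-function's internal state untouched.
        recurrent_model.set_state(saved_state)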
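For context, here is a hedged sketch of how this loss might feed an actor update step. The attribute name actor_optimizer is an assumption for illustration, standing for a Chainer optimizer that has been set up on self.policy.

def update_actor(self, batch):
    # Chainer's Optimizer.update accepts a loss-computing callable: it
    # clears the target link's gradients, backpropagates the returned
    # loss, and applies the parameter update.
    # NOTE: `self.actor_optimizer` is a hypothetical attribute name.
    self.actor_optimizer.update(lambda: self.compute_actor_loss(batch))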