def calc_loss_recurrent(self, frames, actions, rewards, done_list, size_list):
    # TODO: self.max_step -> max_step
    s = Variable(frames.astype(np.float32))
    self.model_target.reset_state()       # refresh model_target's recurrent state
    self.model_target.q_function(s[0])    # prime the target network's initial state with the first frame
    target_q = self.xp.zeros((self.max_step, self.replay_batch_size), dtype=np.float32)
    selected_q_tuple = [None for _ in range(self.max_step)]
    for frame in range(self.max_step):
        q = self.model.q_function(s[frame])                  # Q(s,*): shape is (batch_size, action_num)
        q_dash = self.model_target.q_function(s[frame + 1])  # Q(s',*): shape is (batch_size, action_num)
        max_q_dash = q_dash.data.max(axis=1)                 # max_a Q(s',a): shape is (batch_size,)
        if self.clipping:
            rs = self.xp.sign(rewards[frame])                # clip rewards to {-1, 0, +1}
        else:
            rs = rewards[frame]
        # Bellman target: r + gamma * max_a Q(s',a), with the bootstrap term zeroed for terminal steps
        target_q[frame] = rs + self.xp.logical_not(done_list[frame]).astype(np.float32) * (self.gamma * max_q_dash)
        # Q-value of the action actually taken at this step
        selected_q_tuple[frame] = F.select_item(q, actions[frame].astype(np.int32))
    # Build a (max_step, batch_size) mask that is True only for steps within each episode's real length
    enable = self.xp.broadcast_to(self.xp.arange(self.max_step), (self.replay_batch_size, self.max_step))
    size_list = self.xp.expand_dims(self.xp.asarray(size_list), -1)
    enable = (enable < size_list).T
    selected_q = F.concat(selected_q_tuple, axis=0)
    # Element-wise Huber loss between the Bellman targets and the selected Q-values
    huber_loss = F.huber_loss(
        F.expand_dims(F.flatten(target_q), axis=1),
        F.expand_dims(selected_q, axis=1), delta=1.0)
    huber_loss = F.reshape(huber_loss, enable.shape)
    zeros = self.xp.zeros(enable.shape, dtype=np.float32)
    # Sum the loss only over valid (unmasked) steps
    loss = F.sum(F.where(enable, huber_loss, zeros))  # / self.replay_batch_size
    return loss
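
A minimal sketch of how calc_loss_recurrent might be driven from a training step. The replay-sampling helper and the optimizer attribute used here (sample_episodes, self.optimizer) are assumptions for illustration, not part of the original code; note that the loss function above resets only model_target, so the online model's recurrent state is reset by the caller.

def update_recurrent(self):
    # Hypothetical helper: sample a batch of episode slices from the replay buffer.
    # Shapes follow the loss function above: frames has max_step + 1 time steps
    # so that s[frame + 1] is valid on the last iteration.
    frames, actions, rewards, done_list, size_list = self.sample_episodes()

    # Reset the online network's recurrent state once per sampled batch.
    self.model.reset_state()

    loss = self.calc_loss_recurrent(frames, actions, rewards, done_list, size_list)

    # Standard Chainer update: clear old gradients, backprop, apply the optimizer.
    self.model.cleargrads()
    loss.backward()
    self.optimizer.update()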