# Module-level imports assumed: from functools import reduce; import numpy as np;
# import chainer.functions as F; from chainer import Variable
def calc_loss(self, state, state_dash, actions, rewards, done_list):
    assert state.shape == state_dash.shape
    assert self.replay_batch_size == len(actions) == len(done_list)
    # Flatten each observation into a single feature vector per sample.
    s = state.reshape((state.shape[0], reduce(lambda x, y: x * y, state.shape[1:]))).astype(np.float32)
    s_dash = state_dash.reshape((state_dash.shape[0], reduce(lambda x, y: x * y, state_dash.shape[1:]))).astype(np.float32)
    q = self.model.q_function(s)                    # Q(s, *) from the online network
    q_dash = self.model_target.q_function(s_dash)   # Q(s', *) from the target network
    max_q_dash = np.asarray(list(map(np.max, q_dash.data)), dtype=np.float32)  # max_a' Q(s', a')
    # Start from the current Q-values so non-selected actions contribute zero loss.
    target = q.data.copy()
    for i in range(self.replay_batch_size):
        # Optionally clip the reward to {-1, 0, +1}.
        r = np.sign(rewards[i]) if self.clipping else rewards[i]
        if done_list[i]:
            discounted_sum = r
        else:
            discounted_sum = r + self.gamma * max_q_dash[i]
        # Only the Q-value of the action actually taken receives the Bellman target.
        target[i, actions[i]] = discounted_sum
    loss = F.sum(F.huber_loss(Variable(target), q, delta=1.0))  # / self.replay_batch_size
    return loss, q
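The loop builds the standard DQN target: y_i = r_i for terminal transitions, and y_i = r_i + γ · max_a' Q_target(s'_i, a') otherwise, then the Huber loss is summed over the batch. A minimal sketch of how this loss might be driven in one training step, assuming a Chainer optimizer has already been set up on self.model and a minibatch has been sampled from the replay buffer (the names agent, optimizer, and the batch variables are placeholders, not part of the original code):

# Hypothetical training step; `agent`, `optimizer`, and the sampled minibatch
# (s, s_dash, acts, rs, dones) are assumed to exist already.
loss, q = agent.calc_loss(s, s_dash, acts, rs, dones)
agent.model.cleargrads()   # reset accumulated gradients on the online network
loss.backward()            # backprop the summed Huber loss
optimizer.update()         # apply the gradients (optimizer set up on agent.model)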