def forward(self, state, action, Reward, state_dash, episode_end):
    # Requires at module level: numpy as np; from chainer import cuda, Variable, functions as F
    num_of_batch = state.shape[0]
    s = Variable(state)
    s_dash = Variable(state_dash)

    Q = self.Q_func(s)  # Q(s, *) from the online network

    # Generate target signals (Double DQN style):
    # the online network picks the action, the target network evaluates it.
    tmp2 = self.Q_func(s_dash)
    tmp2 = list(map(np.argmax, tmp2.data.get()))   # argmax_a Q(s', a)
    tmp = self.Q_func_target(s_dash)               # Q'(s', *)
    tmp = list(tmp.data.get())

    # Select Q'(s', argmax_a Q(s', a)) for every sample in the batch
    res1 = []
    for i in range(num_of_batch):
        res1.append(tmp[i][tmp2[i]])
    max_Q_dash = np.asanyarray(res1, dtype=np.float32)

    # Start from the current Q-values and overwrite only the taken action's entry,
    # so every other action contributes zero TD error.
    target = np.asanyarray(Q.data.get(), dtype=np.float32)
    for i in range(num_of_batch):
        if not episode_end[i][0]:
            tmp_ = np.sign(Reward[i]) + self.gamma * max_Q_dash[i]  # clipped reward + discounted bootstrap
        else:
            tmp_ = np.sign(Reward[i])  # terminal transition: no bootstrap term
        action_index = self.action_to_index(action[i])
        target[i, action_index] = tmp_

    # TD-error clipping to [-1, 1]
    td = Variable(cuda.to_gpu(target)) - Q  # TD error
    td_tmp = td.data + 1000.0 * (abs(td.data) <= 1)  # avoid division by zero below
    td_clip = td * (abs(td.data) <= 1) + td / abs(td_tmp) * (abs(td.data) > 1)

    # Regress the clipped TD error toward zero (shape matches the actual batch size)
    zero_val = Variable(cuda.to_gpu(np.zeros((num_of_batch, self.num_of_actions), dtype=np.float32)))
    loss = F.mean_squared_error(td_clip, zero_val)
    return loss, Q
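
For context, the returned loss is what a training step would backpropagate through. Below is a minimal sketch of such a step, assuming the agent holds a Chainer v1-style optimizer as self.optimizer (already set up on the Q-network's parameters); the *_replay arrays are hypothetical minibatch tensors sampled from the replay buffer and are not defined in the code above.

# One gradient update using forward() (sketch, assumptions noted above)
self.optimizer.zero_grads()   # clear accumulated gradients
loss, _ = self.forward(s_replay, a_replay, r_replay,
                       s_dash_replay, episode_end_replay)
loss.backward()               # backprop the clipped-TD loss
self.optimizer.update()       # apply the gradient step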