def update(Q, target_Q, opt, samples, gamma=0.99, target_type='double_dqn'):
    """Perform one optimization step of ``Q`` on a minibatch of transitions.

    Args:
        Q: Online Q-network. It must expose ``xp`` and ``cleargrads()`` and
            be callable on a batch of states. Its gradients are cleared,
            recomputed from the loss and applied through ``opt``.
        target_Q: Target Q-network used only to evaluate next-state values.
            It is evaluated under ``no_backprop_mode``, so no gradient flows
            through it.
        opt: Optimizer whose ``update()`` is called exactly once. It is
            presumably set up on ``Q``; confirm at the call site.
        samples: Non-empty indexable sequence of transitions. Each one is
            indexable as ``(state, action, reward, done, next_state)``.
        gamma: Discount factor applied to the bootstrapped next-state value.
        target_type: Selects how the next-state value is computed.
            ``'dqn'`` uses ``max_b target_Q(s', b)``. ``'double_dqn'``
            chooses ``b`` with ``argmax Q(s', .)`` and evaluates it with
            ``target_Q``.

    Raises:
        ValueError: If ``target_type`` is unsupported or ``samples`` is empty.
    """
    # Validate before any forward pass so a bad argument costs nothing.
    if target_type not in ('dqn', 'double_dqn'):
        raise ValueError('Unsupported target_type: {}'.format(target_type))
    # len() rather than truthiness: `samples` may be a numpy object array.
    if len(samples) == 0:
        raise ValueError('samples must be non-empty')

    xp = Q.xp
    # Stack each transition field across the batch. The batch size and the
    # state shape come from the samples themselves, not from module-level
    # constants. A batch whose length differs from a global minibatch size
    # therefore stays consistent across s/a/r/done/s_next, and no
    # uninitialised rows are fed to the network.
    s = np.asarray([sample[0] for sample in samples], dtype=np.float32)
    a = np.asarray([sample[1] for sample in samples], dtype=np.int32)
    r = np.asarray([sample[2] for sample in samples], dtype=np.float32)
    done = np.asarray([sample[3] for sample in samples], dtype=np.float32)
    s_next = np.asarray([sample[4] for sample in samples], dtype=np.float32)

    # Move to the network's array module (e.g. the GPU if Q lives there).
    s, a, r, done, s_next = (xp.asarray(arr) for arr in (s, a, r, done, s_next))

    # Prediction: Q(s, a) for the action actually taken in each transition.
    y = F.select_item(Q(s), a)

    # Target: r + gamma * (1 - done) * V(s'). Where done == 1, the factor
    # (1 - done) zeroes the bootstrap term. No backprop, so only y gets
    # gradients.
    with chainer.no_backprop_mode():
        if target_type == 'dqn':
            next_value = F.max(target_Q(s_next), axis=1)
        else:  # 'double_dqn': select the action with Q, evaluate it with target_Q.
            next_value = F.select_item(
                target_Q(s_next), F.argmax(Q(s_next), axis=1))
        t = r + gamma * (1 - done) * next_value

    # mean_clipped_loss is defined elsewhere in this file. Presumably it is a
    # clipped (Huber-like) mean error; confirm there.
    loss = mean_clipped_loss(y, t)
    Q.cleargrads()
    loss.backward()
    opt.update()
# Scraped web-page residue, not code: "评论列表" (comment list), "文章目录" (table of contents).