dqn.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目: DeepLearning 作者: Wanwannodao 项目源码 文件源码
def update(Q, target_Q, opt, samples, gamma=0.99, target_type='double_dqn'):
    """Run one optimization step of the online Q-network on a minibatch.

    Args:
        Q: Online Q-network. Its ``xp`` array module is used for the batch
            arrays, and its gradients are cleared and recomputed here.
        target_Q: Target Q-network used to evaluate next-state values.
        opt: Optimizer whose ``update()`` applies the computed gradients.
        samples: Non-empty sequence of transitions, each indexable as
            ``(state, action, reward, done, next_state)``.
        gamma: Discount factor applied to the bootstrapped next-state value.
        target_type: ``'dqn'`` bootstraps with ``max_b target_Q(s', b)``;
            ``'double_dqn'`` picks ``b`` with ``Q`` and evaluates it with
            ``target_Q``.

    Returns:
        The loss Variable for this minibatch (already back-propagated), so
        callers may log it.

    Raises:
        ValueError: If ``target_type`` is unsupported or ``samples`` is empty.
    """
    # Validate before doing any (possibly GPU) work.
    if target_type not in ('dqn', 'double_dqn'):
        raise ValueError('Unsupported target_type: {}'.format(target_type))
    batch_size = len(samples)
    if batch_size == 0:
        raise ValueError('samples must be non-empty')

    xp = Q.xp
    # Size the batch from the samples themselves rather than the global
    # minibatch_size, so a short or oversized batch cannot misalign.
    state_shape = (batch_size, STATE_LENGTH, FRAME_WIDTH, FRAME_HEIGHT)
    s = np.empty(state_shape, dtype=np.float32)
    s_next = np.empty(state_shape, dtype=np.float32)
    for i, sample in enumerate(samples):
        s[i] = sample[0]
        s_next[i] = sample[4]
    a = np.asarray([sample[1] for sample in samples], dtype=np.int32)
    r = np.asarray([sample[2] for sample in samples], dtype=np.float32)
    done = np.asarray([sample[3] for sample in samples], dtype=np.float32)

    # Move the batch to Q's array module (to the GPU if Q lives there).
    s, a, r, done, s_next = (xp.asarray(v) for v in (s, a, r, done, s_next))

    # Prediction: Q(s, a) for the actions actually taken.
    y = F.select_item(Q(s), a)

    # Target: r + gamma * (1 - done) * Q_target(s', b). It is computed
    # without building a graph, so no gradient flows into the target;
    # (1 - done) removes the bootstrap term where done == 1.
    with chainer.no_backprop_mode():
        if target_type == 'dqn':
            next_q = F.max(target_Q(s_next), axis=1)
        else:  # 'double_dqn': choose b with Q, evaluate it with target_Q.
            next_q = F.select_item(target_Q(s_next),
                                   F.argmax(Q(s_next), axis=1))
        t = r + gamma * (1 - done) * next_q

    loss = mean_clipped_loss(y, t)
    Q.cleargrads()
    loss.backward()
    opt.update()
    return loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号