rl-network-train.py source file

python

Project: Deep-Learning-with-Keras  Author: PacktPublishing
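import numpy as np


def get_next_batch(experience, model, num_actions, gamma, batch_size):
    """Sample a random minibatch from replay memory and build Q-learning targets.

    Each stored state is assumed to carry a leading batch axis of size 1,
    since model.predict(...)[0] is used below.
    """
    # Draw transition indices uniformly at random, with replacement.
    batch_indices = np.random.randint(low=0, high=len(experience),
                                      size=batch_size)
    batch = [experience[i] for i in batch_indices]
    X = np.zeros((batch_size, 80, 80, 4))    # network inputs (states)
    Y = np.zeros((batch_size, num_actions))  # target Q-values
    for i in range(len(batch)):
        s_t, a_t, r_t, s_tp1, game_over = batch[i]
        X[i] = s_t
        # Start from the current predictions so only the taken action changes.
        Y[i] = model.predict(s_t)[0]
        # Highest predicted Q-value in the next state.
        Q_sa = np.max(model.predict(s_tp1)[0])
        if game_over:
            # Terminal transition: no future reward to bootstrap from.
            Y[i, a_t] = r_t
        else:
            Y[i, a_t] = r_t + gamma * Q_sa
    return X, Y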


############################# main ###############################

# initialize parameters
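
The target rule used in get_next_batch (reward alone on terminal transitions, reward plus gamma times the best next-state Q-value otherwise) can also be computed for a whole batch with two predict calls instead of two per sample. The following is a minimal sketch under stated assumptions, not the book's code: StubModel and q_targets are hypothetical names introduced here, and StubModel only imitates a model whose predict() maps states of shape (n, 80, 80, 4) to Q-values of shape (n, num_actions).

```python
import numpy as np


class StubModel:
    """Hypothetical stand-in for a trained network (not a Keras class)."""

    def __init__(self, num_actions):
        self.num_actions = num_actions

    def predict(self, states):
        # Fake Q-values: mean pixel value of each state, scaled by action index + 1.
        means = states.reshape(len(states), -1).mean(axis=1, keepdims=True)
        return means * np.arange(1, self.num_actions + 1)


def q_targets(model, s_t, a_t, r_t, s_tp1, game_over, gamma):
    """Batched version of the per-sample target rule (illustrative assumption)."""
    # Start from current predictions so untaken actions keep their values.
    Y = model.predict(s_t).copy()
    # Best predicted Q-value in each next state.
    q_next = model.predict(s_tp1).max(axis=1)
    # Terminal transitions use the bare reward; others add the discounted bootstrap.
    targets = r_t + gamma * q_next * (~game_over)
    Y[np.arange(len(a_t)), a_t] = targets
    return Y


model = StubModel(num_actions=3)
s_t = np.ones((2, 80, 80, 4))
s_tp1 = 2 * np.ones((2, 80, 80, 4))
a_t = np.array([0, 2])                # actions taken
r_t = np.array([1.0, -1.0])           # rewards received
game_over = np.array([False, True])   # second transition is terminal
Y = q_targets(model, s_t, a_t, r_t, s_tp1, game_over, gamma=0.5)
print(Y)
```

Here the first row's target for action 0 becomes 1.0 + 0.5 * 6.0, while the second row's target for action 2 is just the reward, because that transition ends the game. Unlike the original function, this sketch expects states already stacked along the first axis.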