keras_dqn.py source code


Project: DeepRL-FlappyBird    Author: hashbangCoder
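import numpy as np

def get_targets(mini_batch, target_model):
    # mini_batch format: (input_state, action, reward, output_state, tState, epsilon)
    # action is one-hot; tState flags a terminal transition.
    # p is the project's parameter module (provides DISCOUNT; not shown in this snippet).
    actions = np.argmax(np.asarray([item[1] for item in mini_batch]), axis=1).astype(int)
    state_inputs = np.concatenate([exp[3] for exp in mini_batch], axis=0)  # next states s'
    train_inputs = np.concatenate([exp[0] for exp in mini_batch], axis=0)  # current states s
    # Max over actions of the target network's Q-values for each next state s'
    est_values = target_model.predict_on_batch(state_inputs).max(axis=1)
    # One column per action (2 for Flappy Bird). Only the taken action's entry is set;
    # the other entry stays 0, so a plain MSE loss would also pull it toward 0.
    target = np.zeros(shape=(len(mini_batch), 2))
    for item in range(len(mini_batch)):
        # Bellman target: reward + DISCOUNT * max Q(s', .), bootstrapping only if s' is not terminal
        target[item, actions[item]] = mini_batch[item][2] + p.DISCOUNT * est_values[item] * int(not mini_batch[item][-2])
    #target = np.asarray([mini_batch[item][2] + p.DISCOUNT*est_values[item]  if not mini_batch[item][-2] else mini_batch[item][2] for item in range(len(mini_batch))])
    #assert(target.shape[0] == p.batch_size)

    return target, train_inputs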
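The function above builds the standard DQN target y = r + DISCOUNT * max_a' Q_target(s', a') for non-terminal transitions (and y = r for terminal ones), one transition at a time. Below is a minimal sketch, under stated assumptions, of the same computation vectorized with NumPy fancy indexing, plus a mask so that the zero-filled entries for untaken actions can be excluded from the loss. The names `get_targets_vectorized`, `masked_mse`, `StubModel` and the `discount` argument (standing in for `p.DISCOUNT`) are hypothetical, and the masking step is an illustrative choice, not necessarily how the original project trains its network. It assumes each transition is a 6-tuple with the terminal flag at index 4, and that the model object exposes `predict_on_batch(x)` returning an array of shape `(batch, num_actions)`.

```python
import numpy as np


def get_targets_vectorized(mini_batch, target_model, discount, num_actions=2):
    """Hypothetical vectorized variant of get_targets that also returns a loss mask.

    Assumes transitions are (input_state, action_one_hot, reward,
    output_state, terminal_flag, epsilon) tuples.
    """
    actions = np.argmax(np.asarray([exp[1] for exp in mini_batch]), axis=1).astype(int)
    rewards = np.asarray([exp[2] for exp in mini_batch], dtype=float)
    # 1.0 for non-terminal transitions, 0.0 for terminal ones
    not_terminal = 1.0 - np.asarray([exp[4] for exp in mini_batch], dtype=float)
    next_states = np.concatenate([exp[3] for exp in mini_batch], axis=0)
    train_inputs = np.concatenate([exp[0] for exp in mini_batch], axis=0)

    # max_a' Q_target(s', a') for every transition in one batched call
    est_values = np.asarray(target_model.predict_on_batch(next_states)).max(axis=1)

    rows = np.arange(len(mini_batch))
    target = np.zeros((len(mini_batch), num_actions))
    # Fancy indexing writes one Bellman target per row, at the taken action
    target[rows, actions] = rewards + discount * est_values * not_terminal

    # mask is 1 only where target holds a real value
    mask = np.zeros_like(target)
    mask[rows, actions] = 1.0
    return target, train_inputs, mask


def masked_mse(q_pred, target, mask):
    """Mean squared error over the masked (taken-action) entries only."""
    return float(np.sum(mask * (q_pred - target) ** 2) / np.sum(mask))


class StubModel:
    """Hypothetical stand-in for a Keras model: Q(s, .) = [sum(s), -sum(s)]."""

    def predict_on_batch(self, x):
        s = x.reshape(len(x), -1).sum(axis=1)
        return np.stack([s, -s], axis=1)


# Two transitions: the first is non-terminal (action 1), the second terminal (action 0)
batch = [
    (np.zeros((1, 4)), [0, 1], 1.0, np.ones((1, 4)), False, 0.1),
    (np.zeros((1, 4)), [1, 0], -1.0, np.ones((1, 4)), True, 0.1),
]
target, inputs, mask = get_targets_vectorized(batch, StubModel(), discount=0.99)
# target is approximately [[0, 4.96], [-1, 0]]: 1 + 0.99 * 4 for the first row,
# and just the reward -1 for the terminal second row
```

The fancy-indexed assignment replaces the Python loop with a single array operation. Masking the loss is one common way to keep the untaken action's Q-value from being regressed toward 0; another option is to fill those entries with the online network's own predictions so their error is zero under an ordinary MSE loss.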