def get_targets(mini_batch, target_model, discount=None):
    """Build one-step DQN regression targets for a replay mini-batch.

    Each experience in ``mini_batch`` is a tuple
    ``(input_state, action, reward, output_state, tState, epsilon)``.
    For every experience, the column of the taken action is set to

        reward + discount * max(Q_target(output_state))

    and the bootstrap term is dropped when ``tState`` is truthy.
    Every other column of the target row is left at 0.

    Args:
        mini_batch: Sequence of experience tuples in the format above.
            ``action`` must be a fixed-length vector whose argmax is the
            taken action index (e.g. one-hot). ``input_state`` and
            ``output_state`` are arrays that get concatenated along axis 0.
        target_model: Object exposing ``predict_on_batch(states)`` that
            returns a 2-D array of Q-values with one column per action.
        discount: Discount factor for the bootstrap term. Defaults to
            ``p.DISCOUNT`` when omitted.

    Returns:
        ``(target, train_inputs)``. ``target`` has shape
        ``(len(mini_batch), n_actions)``, where ``n_actions`` is the
        number of Q-value columns produced by ``target_model``.
        ``train_inputs`` is every ``input_state`` concatenated along axis 0.
    """
    if discount is None:
        discount = p.DISCOUNT

    batch_size = len(mini_batch)
    # Taken-action index per experience: the argmax of the action vector.
    actions = np.argmax(np.asarray([exp[1] for exp in mini_batch]), axis=1).astype(int)
    next_states = np.concatenate([exp[3] for exp in mini_batch], axis=0)
    train_inputs = np.concatenate([exp[0] for exp in mini_batch], axis=0)

    next_q = np.asarray(target_model.predict_on_batch(next_states))
    # Cast to float64 so the Bellman update runs at full precision even when
    # the model returns float32.
    est_values = next_q.max(axis=1).astype(np.float64)
    rewards = np.asarray([exp[2] for exp in mini_batch], dtype=np.float64).reshape(batch_size)
    # 1.0 for non-terminal transitions and 0.0 for terminal ones
    # (tState is the second-to-last tuple field).
    not_done = np.asarray([not exp[-2] for exp in mini_batch], dtype=np.float64)

    # Size the target from the network's output instead of hard-coding two
    # actions, so any action-space size works.
    target = np.zeros(shape=(batch_size, next_q.shape[1]))
    # NOTE(review): only the taken action's column receives a target; the rest
    # stay 0. If training uses an unmasked loss (e.g. plain MSE), that pulls
    # untaken-action Q-values toward 0. Confirm the loss masks those columns,
    # or seed `target` with the online model's predictions for train_inputs.
    target[np.arange(batch_size), actions] = rewards + discount * est_values * not_done
    return target, train_inputs
# Comment list
# Table of contents