def get_next_batch(experience, model, num_actions, gamma, batch_size,
                   state_shape=(80, 80, 4)):
    """Sample a minibatch from replay memory and build Q-learning targets.

    Transitions are drawn uniformly at random *with replacement* from
    ``experience``. For each sampled transition the target row starts as the
    model's current prediction for ``s_t``. Only the entry for the action
    actually taken is overwritten with the one-step target: ``r_t`` when the
    transition ended the episode, otherwise
    ``r_t + gamma * max(model.predict(s_tp1))``.

    Args:
        experience: Indexable sequence of ``(s_t, a_t, r_t, s_tp1, game_over)``
            tuples. Each state must be assignable into an array of shape
            ``state_shape``; a leading singleton axis (e.g. ``(1, 80, 80, 4)``)
            is accepted because numpy strips it on assignment.
        model: Object whose ``predict(states)`` maps a batch of states to an
            array with one row of ``num_actions`` values per state.
        num_actions: Number of columns in the target matrix ``Y``.
        gamma: Discount factor applied to the best next-state value.
        batch_size: Number of transitions to sample.
        state_shape: Shape of a single state without the batch axis. Defaults
            to the previously hard-coded ``(80, 80, 4)``.

    Returns:
        ``(X, Y)``: ``X`` of shape ``(batch_size, *state_shape)`` holding the
        sampled ``s_t`` states, and ``Y`` of shape ``(batch_size, num_actions)``
        holding the training targets.

    Raises:
        ValueError: If ``experience`` is empty (nothing to sample from).
    """
    if len(experience) == 0:
        raise ValueError("experience is empty; cannot sample a batch")

    batch_indices = np.random.randint(low=0, high=len(experience),
                                      size=batch_size)
    batch = [experience[i] for i in batch_indices]

    X = np.zeros((batch_size,) + tuple(state_shape))
    X_next = np.zeros_like(X)
    Y = np.zeros((batch_size, num_actions))
    if batch_size == 0:
        # Nothing sampled: avoid calling predict() on an empty batch.
        return X, Y

    actions = np.empty(batch_size, dtype=int)
    rewards = np.empty(batch_size)
    terminal = np.empty(batch_size, dtype=bool)
    for i, (s_t, a_t, r_t, s_tp1, game_over) in enumerate(batch):
        X[i] = s_t
        X_next[i] = s_tp1
        actions[i] = a_t
        rewards[i] = r_t
        terminal[i] = game_over

    # Two batched forward passes instead of two per sampled transition.
    Y[:] = model.predict(X)
    q_next = np.max(model.predict(X_next), axis=1)

    # Terminal transitions get the raw reward; others bootstrap from s_tp1.
    targets = np.where(terminal, rewards, rewards + gamma * q_next)
    Y[np.arange(batch_size), actions] = targets
    return X, Y
############################# main ###############################
# initialize parameters
rl-network-train.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录