def _build_policy(env, predictor, epsilon):
    """Build an epsilon-greedy policy from a predictor's per-state Q-values.

    Every state index ``i`` in ``range(env.num_states)`` is one-hot encoded
    (row ``i`` of an identity matrix). All of them are sent to
    ``predictor.predict`` in a single call, as a dict keyed by ``str(i)``.
    The prediction is read as ``q_values[action][i]``. For each state, the
    action with the largest value becomes the greedy choice; on ties
    ``np.argmax`` picks the earliest action in ``env.ACTIONS``.

    Args:
        env: Object exposing ``num_states`` and an indexable ``ACTIONS``
            sequence.
        predictor: Object with a ``predict`` method taking the dict of one-hot
            vectors described above. NOTE(review): its result is assumed to
            be indexable as ``q_values[action][state_index]``; confirm against
            the predictor implementation.
        epsilon: Probability of sampling an action uniformly at random from
            ``env.ACTIONS`` instead of returning the greedy one.

    Returns:
        A callable ``policy(state)`` that takes a state index and returns an
        action.
    """
    n_states = env.num_states
    one_hot = np.eye(n_states)
    features = {str(idx): one_hot[idx] for idx in range(n_states)}
    q_values = predictor.predict(features)

    # Precompute the greedy action for every state once, so each policy
    # call is just a list lookup.
    actions = env.ACTIONS
    greedy = []
    for idx in range(n_states):
        scores = [q_values[act][idx] for act in actions]
        greedy.append(actions[np.argmax(scores)])

    def policy(state) -> str:
        # Explore with probability epsilon. The random sample is drawn from
        # all of env.ACTIONS, so the greedy action can still be chosen here.
        if np.random.random() < epsilon:
            return np.random.choice(env.ACTIONS)
        return greedy[state]

    return policy
评论列表
文章目录