test_limited_action_gridworld.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:BlueWhale 作者: caffe2 项目源码 文件源码
def _build_policy(env, predictor, epsilon):
    eye = np.eye(env.num_states)
    q_values = predictor.predict(
        {str(i): eye[i]
         for i in range(env.num_states)}
    )
    policy_vector = [
        env.ACTIONS[np.argmax([q_values[action][i] for action in env.ACTIONS])]
        for i in range(env.num_states)
    ]

    def policy(state) -> str:
        if np.random.random() < epsilon:
            return np.random.choice(env.ACTIONS)
        else:
            return policy_vector[state]

    return policy
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号