gridworld_base.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:BlueWhale 作者: caffe2 项目源码 文件源码
def _compute_optimal(self):
    """Compute a shortest-path policy toward the goal via breadth-first search.

    The search starts from the first cell (in ``np.where`` order) whose grid
    value equals ``G`` and expands outward. For every not-yet-visited cell
    reached by taking ``action`` from the current cell, that cell's policy
    entry is set to ``self.invert_action(action)``. This is presumably the
    move that steps back toward the goal; it assumes actions are reversible
    (TODO confirm against ``invert_action``). Reached cells that are terminal
    or walls (``W``) are marked visited but get no entry and are not
    expanded further.

    Moves are simulated with ``self.step``, which overwrites ``self._state``.
    If ``_state`` existed before the call, its previous value is restored
    before returning. Any other side effects of ``step`` are not undone.

    Returns:
        numpy.ndarray of dtype ``object`` with the same shape as
        ``self.grid``. Cells that never received an action hold ``None``.
        These include the goal, walls, terminal cells and unreachable cells.

    Raises:
        IndexError: if no cell of ``self.grid`` equals ``G``.
    """
    not_visited = {
        (y, x) for y in range(self.height) for x in range(self.width)
    }
    goal = tuple(int(coords[0]) for coords in np.where(self.grid == G))
    # appendleft() + pop() gives FIFO order, i.e. breadth-first expansion.
    queue = collections.deque([goal])
    # np.object was removed in NumPy 1.24; the builtin works on all versions.
    policy = np.empty(self.grid.shape, dtype=object)
    print("INITIAL POLICY")
    print(policy)

    had_state = hasattr(self, "_state")
    saved_state = getattr(self, "_state", None)
    try:
        while queue:
            current = queue.pop()
            not_visited.discard(current)
            current_index = self._index(current)
            for action in self.possible_next_actions(current_index, True):
                # step() advances self._state, so rewind to `current`
                # before simulating each candidate action.
                self._state = current_index
                next_state, _, _, _ = self.step(action)
                next_state_pos = self._pos(next_state)
                if next_state_pos not in not_visited:
                    continue
                not_visited.remove(next_state_pos)
                if (
                    not self.is_terminal(next_state)
                    and self.grid[next_state_pos] != W
                ):
                    policy[next_state_pos] = self.invert_action(action)
                    queue.appendleft(next_state_pos)
    finally:
        # Don't leave the environment in whatever state the last
        # simulated step produced.
        if had_state:
            self._state = saved_state
    print("FINAL POLICY")
    print(policy)
    return policy
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号