gridworld_enum.py source code

python

Project: BlueWhale    Author: caffe2
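from typing import Dict, List, Optional

import numpy as np


# Method of the surrounding gridworld class; TrainingDataPage and
# preprocess_samples_discrete are defined elsewhere in BlueWhale.
def preprocess_samples(
    self,
    states: List[Dict[str, float]],
    actions: List[str],
    rewards: List[float],
    next_states: List[Dict[str, float]],
    next_actions: List[str],
    is_terminals: List[bool],
    possible_next_actions: List[List[str]],
    reward_timelines: Optional[List[Dict[int, float]]],
) -> TrainingDataPage:
    # Build a training page whose states are one-hot encoded.
    tdp = self.preprocess_samples_discrete(
        states, actions, rewards, next_states, next_actions, is_terminals,
        possible_next_actions, reward_timelines
    )
    # np.where(...)[1] returns the column index of each 1.0 entry, i.e. the
    # enumerated state id. Reshape it into an (N, 1) float32 feature column.
    # This relies on every row being one-hot (exactly one 1.0 per row).
    tdp.states = (
        np.where(tdp.states == 1.0)[1].reshape(-1, 1).astype(np.float32)
    )
    tdp.next_states = (
        np.where(tdp.next_states == 1.0)[1].reshape(-1, 1).astype(np.float32)
    )
    return tdp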
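The key step in `preprocess_samples` is collapsing each one-hot state row into a single enumerated state index. The following is a minimal, self-contained sketch of that step in isolation; the helper name `one_hot_rows_to_index_column` is hypothetical and not part of BlueWhale. It adds a check that every row is exactly one-hot, since otherwise `np.where` silently returns too many or too few indices.

```python
import numpy as np


def one_hot_rows_to_index_column(one_hot: np.ndarray) -> np.ndarray:
    """Turn an (N, K) one-hot matrix into an (N, 1) float32 column of indices.

    Hypothetical helper illustrating the np.where(... == 1.0)[1] conversion.
    """
    rows, cols = np.where(one_hot == 1.0)
    # For a valid one-hot matrix, np.where reports exactly one hit per row,
    # in row order, so the row indices must be 0, 1, ..., N-1.
    if not np.array_equal(rows, np.arange(one_hot.shape[0])):
        raise ValueError("every row must contain exactly one 1.0 entry")
    return cols.reshape(-1, 1).astype(np.float32)


states = np.array(
    [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32
)
print(one_hot_rows_to_index_column(states).ravel())  # [1. 0. 2.]
```

For inputs already known to be one-hot, `np.argmax(one_hot, axis=1)` gives the same indices; the `np.where` form used in the original has the advantage that a malformed row changes the output length, which the check above turns into an explicit error.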