classifier_tf.py source code

python

Project: human-rl  Author: gsastry
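import numpy as np

# Method of the classifier class in classifier_tf.py; process_image and
# self.hparams are defined elsewhere in that module.
def load_features_episode(self, episode):
    features = {}
    # Only frames where an action was taken contribute features.
    observations = [frame.observation for frame in episode.frames if frame.has_action()]
    # Add a leading axis to each observation and join them: (num_frames, ...).
    features['observation'] = np.concatenate(
        [np.expand_dims(observation, axis=0) for observation in observations], axis=0)
    images = [frame.image for frame in episode.frames if frame.has_action()]
    # Preprocess each image and join the results along axis 0.
    features['image'] = np.concatenate(
        [process_image(image, self.hparams) for image in images], axis=0)
    actions = [frame.get_proposed_action() for frame in episode.frames if frame.has_action()]
    # Shape (num_frames, 1): one proposed action per kept frame.
    features['action'] = np.expand_dims(np.array(actions), axis=1)
    # Positions of the kept frames within the full episode.
    features['index'] = np.array(
        [i for i, frame in enumerate(episode.frames) if frame.has_action()])
    return features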
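The method above relies on frame, episode, and process_image objects from the human-rl project that are not shown here. The following is a minimal sketch that makes the data flow concrete. It assumes hypothetical stand-ins: a `Frame` dataclass whose `has_action()` returns true when `action` is not `None`, an `Episode` holding a list of frames, and a `process_image` that scales pixels to [0, 1] and adds a leading batch axis. None of these reflect the project's actual definitions.

```python
import numpy as np
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Frame:
    # Hypothetical stand-in for the human-rl frame object.
    observation: np.ndarray
    image: np.ndarray
    action: Optional[int] = None

    def has_action(self) -> bool:
        return self.action is not None

    def get_proposed_action(self) -> int:
        return self.action


@dataclass
class Episode:
    # Hypothetical stand-in: an episode is just a list of frames.
    frames: List[Frame] = field(default_factory=list)


def process_image(image: np.ndarray, hparams) -> np.ndarray:
    # Hypothetical preprocessing: scale to [0, 1] and add a leading batch axis
    # so per-frame results can be concatenated along axis 0.
    return np.expand_dims(image.astype(np.float32) / 255.0, axis=0)


def load_features_episode(episode: Episode, hparams=None) -> dict:
    # Same feature layout as the method above, written as a free function.
    # The has_action() filter is applied once so all arrays stay aligned.
    kept = [(i, f) for i, f in enumerate(episode.frames) if f.has_action()]
    return {
        # np.stack(..., axis=0) equals concatenating expand_dims(..., axis=0).
        'observation': np.stack([f.observation for _, f in kept], axis=0),
        'image': np.concatenate(
            [process_image(f.image, hparams) for _, f in kept], axis=0),
        'action': np.expand_dims(
            np.array([f.get_proposed_action() for _, f in kept]), axis=1),
        'index': np.array([i for i, _ in kept]),
    }


# Four frames; the second has no action and is dropped.
frames = [
    Frame(np.zeros(4), np.zeros((8, 8, 3), dtype=np.uint8), action=a)
    for a in [1, None, 0, 2]
]
features = load_features_episode(Episode(frames))
for name, value in features.items():
    print(name, value.shape)
```

Computing the filter once is a small design change from the original, which filters `episode.frames` separately for each feature. Both approaches give the same arrays as long as `has_action()` returns the same answer on every call.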