classifier_tf.py source code

python

Project: human-rl  Author: gsastry
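import os

import numpy as np

# `frame` is the human-rl episode I/O module (its import path is not shown in this excerpt).


def predict_episodes(self, model, episode_paths, n=None, out_dir=None, prefix="model/"):
    """Score every frame of each episode with `model` and store the results in frame.info."""
    # Optionally evaluate a random subset of n episodes (sampled without replacement).
    if n is not None:
        episode_paths = np.random.choice(episode_paths, n, replace=False)
    if out_dir is not None:
        os.makedirs(out_dir, exist_ok=True)
    for ep, episode_path in enumerate(episode_paths):
        episode = frame.load_episode(episode_path)
        features = self.load_features_episode(episode)
        prediction = model.predict_proba(features)
        # Attach each frame's score and its thresholded label under the given key prefix.
        for i in range(len(prediction)):
            episode.frames[i].info[prefix + "score"] = prediction[i]
            episode.frames[i].info[prefix + "label"] = model.apply_threshold(prediction[i])
        # Overwrite the source file, or write "<index>.pkl.gz" into out_dir if one is given.
        out_path = episode_path
        if out_dir is not None:
            out_path = os.path.join(out_dir, "{}.pkl.gz".format(ep))
        frame.save_episode(out_path, episode)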
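The method above runs a sample, predict, annotate, save loop over episode files. The following is a minimal self-contained sketch of that loop under stated assumptions, not the human-rl implementation: the episode objects (`SimpleNamespace` frames with `features` and `info`), the `ThresholdModel` class (a logistic score on each frame's feature sum), the helpers `make_episode`, `save_episode`, `load_episode` and `annotate_episodes`, and the gzip-pickle storage format are all hypothetical stand-ins for the project's `frame` module and classifier. The sketch also swaps the global `np.random.choice` for an optional `numpy.random.Generator` so that sampling can be made reproducible.

```python
import gzip
import os
import pickle
import tempfile
from types import SimpleNamespace

import numpy as np


class ThresholdModel:
    """Hypothetical classifier: logistic score of each frame's feature sum."""

    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def predict_proba(self, features):
        # features has shape (num_frames, num_features); returns one score per frame.
        return 1.0 / (1.0 + np.exp(-features.sum(axis=1)))

    def apply_threshold(self, score):
        return bool(score >= self.threshold)


def make_episode(feature_rows):
    # Hypothetical episode: .frames, each frame with .features and an .info dict.
    frames = [SimpleNamespace(features=np.asarray(r, dtype=float), info={}) for r in feature_rows]
    return SimpleNamespace(frames=frames)


def save_episode(path, episode):
    # Assumed storage format: gzip-compressed pickle, matching the ".pkl.gz" suffix.
    with gzip.open(path, "wb") as f:
        pickle.dump(episode, f)


def load_episode(path):
    with gzip.open(path, "rb") as f:
        return pickle.load(f)


def annotate_episodes(model, episode_paths, n=None, out_dir=None, prefix="model/", rng=None):
    """Score each frame, store score and label in frame.info, and save the episode."""
    rng = rng if rng is not None else np.random.default_rng()
    if n is not None:
        # Sample n distinct episode paths.
        episode_paths = rng.choice(episode_paths, n, replace=False)
    if out_dir is not None:
        os.makedirs(out_dir, exist_ok=True)
    written = []
    for ep, path in enumerate(episode_paths):
        episode = load_episode(path)
        features = np.stack([fr.features for fr in episode.frames])
        scores = model.predict_proba(features)
        for fr, score in zip(episode.frames, scores):
            fr.info[prefix + "score"] = float(score)
            fr.info[prefix + "label"] = model.apply_threshold(score)
        # Overwrite in place, or write "<index>.pkl.gz" into out_dir.
        out_path = path if out_dir is None else os.path.join(out_dir, "{}.pkl.gz".format(ep))
        save_episode(out_path, episode)
        written.append(out_path)
    return written


def demo():
    # Write one two-frame episode, annotate it into a separate directory, read labels back.
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "episode.pkl.gz")
        save_episode(src, make_episode([[1.0, 2.0], [-3.0, -1.0]]))
        out = annotate_episodes(ThresholdModel(), [src], out_dir=os.path.join(tmp, "out"))
        return [fr.info["model/label"] for fr in load_episode(out[0]).frames]
```

Calling `demo()` returns `[True, False]`: the feature sums 3.0 and -4.0 map to scores of about 0.95 and 0.02, which fall on either side of the 0.5 threshold. As in the original method, files written to `out_dir` are named by loop index, so a sampled episode's original file name is not preserved in the output.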