agents.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:uct_atari 作者: 5vision 项目源码 文件源码
def __init__(self, model_path, flip_map=None, gray_state=True, **kwargs):
        """Create an agent whose policy is a saved Keras model.

        Args:
            model_path: Path handed to ``load_model`` to load the model.
            flip_map: Optional sequence whose length must equal the number
                of model outputs (``model.output_shape[1]``). It is stored
                unchanged as ``self.flip_map``.
            gray_state: When False, the model's frame axis is divided by 3
                to obtain ``n_frames`` (presumably 3 colour channels per
                frame -- TODO confirm).
            **kwargs: Forwarded to the parent class ``__init__``.

        Raises:
            ValueError: If ``flip_map`` is given and its length differs from
                the number of model outputs.
        """
        # Load the model, mapping the custom loss name 'loss_fn' to
        # categorical_crossentropy during deserialization.
        model = load_model(
            model_path,
            custom_objects={'loss_fn': categorical_crossentropy}
        )
        n_actions = model.output_shape[1]
        # Explicit raise rather than `assert`, so the check is not stripped
        # when Python runs with -O.
        if flip_map is not None and len(flip_map) != n_actions:
            raise ValueError(
                'flip_map has %d entries but the model has %d outputs'
                % (len(flip_map), n_actions)
            )

        super(KerasAgent, self).__init__(n_actions=n_actions, **kwargs)

        self.gray_state = gray_state

        input_shape = model.input_shape
        # A 5-D input is treated as recurrent, with the frame axis at index 2
        # (index 1 is presumably time -- TODO confirm); otherwise frames are
        # on axis 1.
        self.rnn = len(input_shape) == 5
        self.n_frames = input_shape[2] if self.rnn else input_shape[1]

        if not gray_state:
            # Floor division keeps n_frames an int; true division (`/=`)
            # would turn it into a float on Python 3.
            self.n_frames //= 3
        # Height and width are the last two axes in both layouts. Slicing
        # from the end avoids the 3-element unpack that input_shape[2:]
        # produces for the 5-D (recurrent) case.
        self.height, self.width = input_shape[-2:]
        self.model = model
        self.flip_map = flip_map
        self.reset()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号