def __init__(self, model_path, flip_map=None, gray_state=True, **kwargs):
    """Load a saved Keras model and configure the agent from its shapes.

    Args:
        model_path: Path handed to ``load_model``. The custom object name
            ``'loss_fn'`` stored in the saved model is resolved to
            ``categorical_crossentropy``.
        flip_map: Optional sequence with exactly one entry per model output
            (i.e. per action); stored as ``self.flip_map``. Presumably maps
            each action to its mirrored counterpart -- TODO confirm against
            the code that reads it.
        gray_state: Stored as ``self.gray_state``. When False, the model's
            frame axis is assumed to hold 3 channels per frame (presumably
            RGB -- confirm), so the frame count is that axis divided by 3.
        **kwargs: Forwarded to the parent class ``__init__`` together with
            ``n_actions`` (the model's output width).

    Raises:
        AssertionError: If ``flip_map`` is given and its length differs from
            the number of model outputs.
    """
    # load model
    model = load_model(
        model_path,
        custom_objects={'loss_fn': categorical_crossentropy},
    )
    n_actions = model.output_shape[1]
    if flip_map is not None:
        assert n_actions == len(flip_map), (
            'flip_map has %s entries but the model has %s outputs'
            % (len(flip_map), n_actions)
        )
    super(KerasAgent, self).__init__(n_actions=n_actions, **kwargs)
    self.gray_state = gray_state

    input_shape = model.input_shape
    if len(input_shape) == 5:
        # 5-D input is treated as a recurrent model; frames are read from
        # axis 2 (axis 1 is presumably the time/sequence axis -- confirm).
        self.n_frames = input_shape[2]
        self.rnn = True
    else:
        self.n_frames = input_shape[1]
        self.rnn = False
    if not gray_state:
        # Floor division keeps n_frames an int; plain `/` produces a float
        # in Python 3, which breaks any use of it as a count.
        self.n_frames //= 3
    # Height and width are the last two axes for both 4-D and 5-D inputs.
    # Slicing with [2:] would give three values for a 5-D input and fail
    # to unpack.
    self.height, self.width = input_shape[-2:]

    self.model = model
    self.flip_map = flip_map
    self.reset()
评论列表
文章目录