def __init__(self, environment: EnvironmentInterface, memory: Memory, image_size: int,
             random_action_policy: RandomActionPolicy, batch_size: int, discount: float,
             should_load_model: bool, should_save: bool, action_type: Any,
             create_model: Callable[[Any, int], Model], batches_per_frame: int):
    """Wire up the agent's collaborators, training bookkeeping and model.

    Args:
        environment: Stored as ``self.environment``.
        memory: Stored as ``self.memory``.
        image_size: Used for both of the first two dimensions of the model
            input shape ``(image_size, image_size, StateAssembler.FRAME_COUNT)``.
        random_action_policy: Stored as ``self.random_action_policy``.
        batch_size: Stored as ``self.batch_size``.
        discount: Stored as ``self.discount``.
        should_load_model: Forwarded to ``TrainingInfo``. When true and a file
            exists at ``self.MODEL_PATH``, the model is loaded from that file
            instead of being built with ``create_model``.
        should_save: Stored as ``self.should_save``.
        action_type: Stored as ``self.action_type``. Its ``COUNT`` attribute
            is the second argument to ``create_model`` (presumably the number
            of actions; confirm).
        create_model: Factory called as
            ``create_model(input_shape, action_type.COUNT)`` when no saved
            model is loaded.
        batches_per_frame: When truthy, overwrites
            ``training_info['batches_per_frame']``. A falsy value (e.g. ``0``
            or ``None``) leaves whatever value ``training_info`` already holds.
    """
    # Plain collaborators and hyper-parameters.
    self.environment = environment
    self.memory = memory
    self.random_action_policy = random_action_policy
    self.image_size = image_size
    self.batch_size = batch_size
    self.discount = discount
    self.action_type = action_type
    self.should_save = should_save

    # Shutdown bookkeeping. The currently installed SIGINT handler is kept,
    # presumably so it can be restored or chained to later; confirm.
    self.should_exit = False
    self.default_sigint_handler = signal.getsignal(signal.SIGINT)

    # Training statistics. TrainingInfo receives should_load_model,
    # presumably to decide whether previously saved values are restored; confirm.
    info = TrainingInfo(should_load_model)
    self.training_info = info
    # NOTE(review): 1000 looks like the averaging window, seeded with the
    # value held in training_info. Confirm against RunningAverage.
    self.mean_training_time = RunningAverage(1000, info['mean_training_time'])
    if batches_per_frame:
        info['batches_per_frame'] = batches_per_frame

    # Reuse the saved model only when loading was requested AND the file is
    # actually present; otherwise build a fresh one.
    saved_model_available = should_load_model and Path(self.MODEL_PATH).is_file()
    if saved_model_available:
        self.model = load_model(self.MODEL_PATH)
    else:
        input_shape = (self.image_size, self.image_size, StateAssembler.FRAME_COUNT)
        self.model = create_model(input_shape, action_type.COUNT)
# (Removed non-code web-page residue here: "comment list" / "table of contents" page labels.)