dqn.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:SerpentAI 作者: SerpentAI 项目源码 文件源码
def __init__(
        self,
        input_shape=None,
        input_mapping=None,
        replay_memory_size=10000,
        batch_size=32,
        action_space=None,
        max_steps=1000000,
        observe_steps=None,
        initial_epsilon=1.0,
        final_epsilon=0.1,
        gamma=0.99,
        model_file_path=None,
        model_learning_rate=2.5e-4,
        override_epsilon=False
    ):
        """Initialize the agent's bookkeeping, replay memory, policy and model.

        Args:
            input_shape: Stored as ``self.input_shape``. Not used directly in
                this constructor (presumably read by ``_initialize_model`` --
                confirm against that method).
            input_mapping: Forwarded to
                ``_generate_action_space_combination_input_mapping``; the
                result is stored as ``self.action_input_mapping``.
            replay_memory_size: Passed to ``ReplayMemory`` as ``memory_size``.
                Also used to derive the default ``observe_steps``.
            batch_size: Stored as ``self.batch_size``.
            action_space: Object exposing a ``combinations`` attribute whose
                length becomes ``self.action_count``. Effectively required:
                leaving the ``None`` default raises ``AttributeError`` when
                ``combinations`` is accessed.
            max_steps: Stored and passed to ``EpsilonGreedyQPolicy``.
            observe_steps: Stored as ``self.observe_steps``. Any falsy value
                (``None`` or ``0``) is replaced by
                ``0.1 * replay_memory_size``, which is a float.
            initial_epsilon: Passed to ``EpsilonGreedyQPolicy``; also the
                starting value of ``self.previous_epsilon``.
            final_epsilon: Passed to ``EpsilonGreedyQPolicy``.
            gamma: Stored as ``self.gamma`` (presumably the reward discount
                factor -- confirm where it is consumed).
            model_file_path: When not ``None``, ``load_model_weights`` is
                called with this path right after the model is built.
            model_learning_rate: Stored as ``self.model_learning_rate`` before
                ``_initialize_model`` runs.
            override_epsilon: Forwarded to ``load_model_weights``; ignored
                when ``model_file_path`` is ``None``.
        """
        self.type = "DQN"
        self.input_shape = input_shape
        self.replay_memory = ReplayMemory(memory_size=replay_memory_size)
        self.batch_size = batch_size
        self.action_space = action_space
        # Requires a real action_space: the None default fails on `.combinations`.
        self.action_count = len(self.action_space.combinations)
        self.action_input_mapping = self._generate_action_space_combination_input_mapping(input_mapping)
        self.frame_stack = None
        self.max_steps = max_steps
        # `or` treats 0 like None, and the fallback is a float (e.g. 1000.0).
        self.observe_steps = observe_steps or (0.1 * replay_memory_size)
        self.current_observe_step = 0
        self.current_step = 0
        self.initial_epsilon = initial_epsilon
        self.final_epsilon = final_epsilon
        self.previous_epsilon = initial_epsilon
        self.epsilon_greedy_q_policy = EpsilonGreedyQPolicy(
            initial_epsilon=self.initial_epsilon,
            final_epsilon=self.final_epsilon,
            max_steps=self.max_steps
        )
        self.gamma = gamma
        self.current_action = None
        self.current_action_index = None
        self.current_action_type = None
        self.first_run = True
        # The agent starts in the "OBSERVE" mode.
        self.mode = "OBSERVE"

        # The learning rate is assigned before the model is built
        # (presumably _initialize_model reads it -- confirm).
        self.model_learning_rate = model_learning_rate
        self.model = self._initialize_model()

        # Optionally restore previously saved weights into the fresh model.
        if model_file_path is not None:
            self.load_model_weights(model_file_path, override_epsilon)

        self.model_loss = 0

        # Disabled TensorBoard callback wiring, left in place for reference.
        #self.keras_callbacks = list()
        #self.keras_callbacks.append(TensorBoard(log_dir='/tmp/logs', histogram_freq=0, batch_size=32, write_graph=True, write_grads=True, write_images=True, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None))

        self.visual_debugger = VisualDebugger()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号