episode.py source code

python

Project: agent-trainer    Author: lopespm
import numpy as np

def _learn_from_memories(self, replay_memories, q_network, global_step):
    # No training until the replay memory warm-up phase is over.
    if self._pre_learning_stage(global_step):
        loss = 0.0
        return loss

    # Sample a training minibatch, biased toward recent memories.
    sampled_replay_memories = replay_memories.sample(sample_size=self.hyperparameters.REPLAY_MEMORIES_TRAIN_SAMPLE_SIZE,
                                                     recent_memories_span=self.hyperparameters.REPLAY_MEMORIES_RECENT_SAMPLE_SPAN)
    # Batched forward pass over the successor states; nanmax takes the best
    # action value per row while ignoring NaN entries.
    consequent_states = [replay_memory.consequent_state for replay_memory in sampled_replay_memories]
    max_q_consequent_states = np.nanmax(q_network.forward_pass_batched(consequent_states), axis=1)

    # Build one train bundle (state, action index, target Q-value) per sample.
    train_bundles = [None] * self.hyperparameters.REPLAY_MEMORIES_TRAIN_SAMPLE_SIZE
    discount_factor = self.hyperparameters.Q_UPDATE_DISCOUNT_FACTOR
    for idx, replay_memory in enumerate(sampled_replay_memories):
        target_action_q_value = float(self._q_target(replay_memory=replay_memory,
                                                     max_q_consequent_state=max_q_consequent_states[idx],
                                                     discount_factor=discount_factor))
        train_bundles[idx] = q_network.create_train_bundle(state=replay_memory.initial_state,
                                                           action_index=replay_memory.action_index,
                                                           target_action_q_value=target_action_q_value)

    # Train on the minibatch; the step is offset by the warm-up threshold so
    # step counting starts from the moment learning actually begins.
    loss = q_network.train(train_bundles, global_step - self.hyperparameters.REPLAY_MEMORIES_MINIMUM_SIZE_FOR_LEARNING)
    return loss
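
The two helpers referenced above, _pre_learning_stage and _q_target, are not shown on this page. Below is a minimal sketch of what they plausibly compute, assuming the standard DQN recipe: a warm-up check against REPLAY_MEMORIES_MINIMUM_SIZE_FOR_LEARNING (consistent with the step offset passed to q_network.train above) and the one-step Q-learning target r + discount * max_a Q(s', a). These bodies are assumptions for illustration, not the project's actual code; in particular, the terminal-transition check via consequent_state is None is guessed from context.

def _pre_learning_stage(self, global_step):
    # Assumed: defer learning until the replay memory holds enough samples.
    return global_step < self.hyperparameters.REPLAY_MEMORIES_MINIMUM_SIZE_FOR_LEARNING

def _q_target(self, replay_memory, max_q_consequent_state, discount_factor):
    # Standard one-step Q-learning target: r + discount * max_a Q(s', a).
    # Assumed: terminal transitions carry no consequent state, so the
    # bootstrap term is dropped and the target is the raw reward.
    if replay_memory.consequent_state is None:
        return replay_memory.reward
    return replay_memory.reward + discount_factor * max_q_consequent_state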