def _learn_from_memories(self, replay_memories, q_network, global_step):
    # Skip training until enough experience has been collected.
    if self._pre_learning_stage(global_step):
        loss = 0.0
        return loss

    # Sample a mini-batch of transitions, biased towards recent experiences.
    sampled_replay_memories = replay_memories.sample(
        sample_size=self.hyperparameters.REPLAY_MEMORIES_TRAIN_SAMPLE_SIZE,
        recent_memories_span=self.hyperparameters.REPLAY_MEMORIES_RECENT_SAMPLE_SPAN)

    # Evaluate all consequent (next) states in a single batched forward pass;
    # np.nanmax takes the best action value per state while ignoring NaN entries.
    consequent_states = [replay_memory.consequent_state for replay_memory in sampled_replay_memories]
    max_q_consequent_states = np.nanmax(q_network.forward_pass_batched(consequent_states), axis=1)

    train_bundles = [None] * self.hyperparameters.REPLAY_MEMORIES_TRAIN_SAMPLE_SIZE
    discount_factor = self.hyperparameters.Q_UPDATE_DISCOUNT_FACTOR

    # Build one training bundle per sampled transition: the Bellman target for the
    # action that was taken, paired with the state it was taken in.
    for idx, replay_memory in enumerate(sampled_replay_memories):
        target_action_q_value = float(self._q_target(
            replay_memory=replay_memory,
            max_q_consequent_state=max_q_consequent_states[idx],
            discount_factor=discount_factor))
        train_bundles[idx] = q_network.create_train_bundle(
            state=replay_memory.initial_state,
            action_index=replay_memory.action_index,
            target_action_q_value=target_action_q_value)

    # Train on the batch; the step offset excludes the pre-learning warm-up period.
    loss = q_network.train(
        train_bundles,
        global_step - self.hyperparameters.REPLAY_MEMORIES_MINIMUM_SIZE_FOR_LEARNING)
    return loss
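
The method above delegates to two helpers that are not shown in this listing, `_pre_learning_stage` and `_q_target`. As a rough sketch only (not the project's actual code), the first presumably gates learning on the warm-up threshold and the second presumably computes the standard Bellman target r + γ · max_a' Q(s', a'); the `reward` and `is_terminal` attributes on the replay memory are assumed names used here for illustration.

# Hypothetical sketches of the two helpers, under the assumptions stated above.
def _pre_learning_stage(self, global_step):
    # No learning until enough steps have been collected to fill the replay memory.
    return global_step < self.hyperparameters.REPLAY_MEMORIES_MINIMUM_SIZE_FOR_LEARNING

def _q_target(self, replay_memory, max_q_consequent_state, discount_factor):
    # `reward` / `is_terminal` are assumed attribute names. A terminal transition
    # has no future return; otherwise bootstrap from the best action value of the
    # consequent state: r + gamma * max_a' Q(s', a').
    if replay_memory.is_terminal:
        return replay_memory.reward
    return replay_memory.reward + discount_factor * max_q_consequent_state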