def preprocess_samples(
    self,
    states: List[Dict[str, float]],
    actions: List[str],
    rewards: List[float],
    next_states: List[Dict[str, float]],
    next_actions: List[str],
    is_terminals: List[bool],
    possible_next_actions: List[List[str]],
    reward_timelines: Optional[List[Dict[int, float]]],
) -> TrainingDataPage:
    """Build a training page whose state features are single index columns.

    All arguments are forwarded, unchanged and in order, to
    ``self.preprocess_samples_discrete``. On the page it returns, ``states``
    and ``next_states`` are then each replaced by the column indices of
    their entries equal to ``1.0``, stored as a float32 array with one
    column.

    Returns:
        The page produced by ``preprocess_samples_discrete``, with its
        ``states`` and ``next_states`` attributes overwritten in place.
    """

    def hot_column_indices(matrix):
        # With a single argument, np.where returns one index array per
        # axis; element [1] holds the column positions of the matches.
        # NOTE(review): assumes every row holds exactly one 1.0 (one-hot),
        # otherwise the result no longer has one row per sample - confirm.
        columns = np.where(matrix == 1.0)[1]
        as_single_column = columns.reshape(-1, 1)
        return as_single_column.astype(np.float32)

    page = self.preprocess_samples_discrete(
        states,
        actions,
        rewards,
        next_states,
        next_actions,
        is_terminals,
        possible_next_actions,
        reward_timelines,
    )
    page.states = hot_column_indices(page.states)
    page.next_states = hot_column_indices(page.next_states)
    return page
评论列表
文章目录