def preprocess_samples(
    self,
    states: List[Dict[str, float]],
    actions: List[Dict[str, float]],
    rewards: List[float],
    next_states: List[Dict[str, float]],
    next_actions: List[Dict[str, float]],
    is_terminals: List[bool],
    possible_next_actions: List[List[Dict[str, float]]],
    reward_timelines: List[Dict[int, float]],
) -> TrainingDataPage:
    """Preprocess samples, then collapse each state row to a single index.

    Delegates to ``GridworldContinuous.preprocess_samples`` and then replaces
    ``tdp.states`` and ``tdp.next_states`` with ``(n, 1)`` float32 arrays
    holding, for each row, the column index whose value equals 1.0.

    NOTE(review): this assumes the parent produces 2-D one-hot state
    matrices (one row per sample, exactly one 1.0 per row) -- confirm
    against GridworldContinuous.preprocess_samples.

    Returns:
        The TrainingDataPage built by the parent, with ``states`` and
        ``next_states`` replaced by index columns.

    Raises:
        ValueError: if a state matrix is not 2-D, or if any row does not
            contain exactly one entry equal to 1.0 (previously this silently
            produced an array misaligned with the other sample fields).
    """

    def _one_hot_to_index(matrix, name):
        # Map each one-hot row to its hot column index, shaped (n, 1) float32.
        matrix = np.asarray(matrix)
        if matrix.ndim != 2:
            raise ValueError(
                "expected 2-D one-hot {}, got shape {}".format(name, matrix.shape)
            )
        rows, cols = np.where(matrix == 1.0)
        # np.where reports hits in row-major order, so a well-formed one-hot
        # matrix yields rows == [0, 1, ..., n-1]. A row with no hit or with
        # several hits would shift every later index onto the wrong sample.
        if not np.array_equal(rows, np.arange(matrix.shape[0])):
            raise ValueError(
                "{} must have exactly one entry equal to 1.0 per row".format(name)
            )
        return cols.reshape(-1, 1).astype(np.float32)

    tdp = GridworldContinuous.preprocess_samples(
        self, states, actions, rewards, next_states, next_actions,
        is_terminals, possible_next_actions, reward_timelines
    )
    tdp.states = _one_hot_to_index(tdp.states, "states")
    tdp.next_states = _one_hot_to_index(tdp.next_states, "next_states")
    return tdp
评论列表
文章目录