def get_max_q_values(
    self,
    next_states: np.ndarray,
    possible_next_actions: Optional[np.ndarray] = None,
    use_target_network: Optional[bool] = True
) -> np.ndarray:
    """Return the largest Q-value along axis 1 for each entry of ``next_states``.

    Q-values for every action are obtained from
    ``self.get_q_values_all_actions(next_states, use_target_network)``.
    When ``possible_next_actions`` is given, every position where it is
    falsy has ``self.ACTION_NOT_POSSIBLE_VAL`` added to its Q-value before
    the maximum is taken.

    The array returned by ``get_q_values_all_actions`` is never modified;
    the penalty is applied to a private copy.

    :param next_states: States forwarded unchanged to
        ``get_q_values_all_actions``.
    :param possible_next_actions: Optional mask that must broadcast against
        the Q-value array; truthy entries mark actions that keep their
        Q-value, falsy entries are penalised. ``None`` disables masking.
        Presumably shaped (batch, num_actions) -- TODO confirm with callers.
    :param use_target_network: Forwarded unchanged to
        ``get_q_values_all_actions``.
    :return: Maximum over axis 1 with that axis kept as size 1.

    NOTE(review): ``ACTION_NOT_POSSIBLE_VAL`` is presumably a large negative
    number so penalised actions never win the max -- confirm on the class.
    If every action in a row is masked out, the result for that row is the
    largest *penalised* value.
    """
    q_values = self.get_q_values_all_actions(
        next_states, use_target_network
    )
    if possible_next_actions is not None:
        # 0 where the action is allowed, ACTION_NOT_POSSIBLE_VAL elsewhere.
        penalty = np.multiply(
            np.logical_not(possible_next_actions),
            self.ACTION_NOT_POSSIBLE_VAL
        )
        # Copy first: adding in place to the helper's return value would
        # leak the penalty into a buffer this method does not own (and
        # compound it on repeated calls if that buffer is cached).
        q_values = np.array(q_values)
        # In-place add keeps the Q-values' dtype (e.g. float32).
        q_values += penalty
    return np.max(q_values, axis=1, keepdims=True)
# (Web-page residue, not part of the code: "Comment list" / "Table of contents".)