def get_q_values_all_actions(
self, states: np.ndarray, use_target_network: bool = True
) -> np.ndarray:
"""
Takes in a set of states and runs the test Q Network on them.
Creates Q(states, actions), a blob with shape (batch_size, action_dim).
Q(states, actions)[i][j] is an approximation of Q*(states[i], action_j).
Note that action_j takes on every possible action (of which there are
self.action_dim_. Stores blob in self.output_blob and returns its value.
:param states: Numpy array with shape (batch_size, state_dim). Each row
contains a representation of a state.
:param possible_next_actions: Numpy array with shape (batch_size, action_dim).
possible_next_actions[i][j] = 1 iff the agent can take action j from
state i.
:use_target_network: Boolean that indicates whether or not to use this
trainer's TargetNetwork to compute Q values.
"""
if use_target_network:
return self.target_network.target_values(states)
return self.score(states)
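
# Hedged usage sketch (not from the original file): how a
# (batch_size, action_dim) Q-value array like the one returned above is
# typically consumed. The `possible_next_actions` mask here is an
# illustrative assumption: possible_next_actions[i][j] == 1 iff action j
# is legal in state i.
import numpy as np

q_values = np.array([[1.0, 3.0, 2.0],
                     [0.5, -1.0, 4.0]])          # stand-in for the method's output
possible_next_actions = np.array([[1, 0, 1],
                                  [1, 1, 1]])    # 1 = action is allowed

# Mask out illegal actions before reducing over the action dimension.
masked_q = np.where(possible_next_actions == 1, q_values, -np.inf)
max_q = masked_q.max(axis=1)        # max_a Q(s, a) per state, e.g. for a Bellman target
greedy = masked_q.argmax(axis=1)    # greedy legal action index per state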