def get_e_qval_sym(self, obs_var, policy, **kwargs):
    """Build the symbolic Q-value of ``obs_var`` under ``policy``.

    For a ``StochasticPolicy`` the result is ``A + V``. Here ``A`` comes from
    ``self.get_e_A_sym``, fed with the network's ``P`` matrix and ``mu``
    output plus the policy's mean and a diagonal matrix of ``exp(log_std)``.
    For any other policy, the Q-function is evaluated at the action returned
    by ``policy.get_action_sym``.

    Args:
        obs_var: Symbolic observation input.
        policy: Policy to evaluate. A ``StochasticPolicy`` must provide
            ``dist_info_sym`` returning a dict with ``'mean'`` and
            ``'log_std'`` entries. Other policies must provide
            ``get_action_sym``.
        **kwargs: Forwarded to ``self.get_output_sym`` (stochastic case) or
            ``self.get_qval_sym`` (other policies).

    Returns:
        The symbolic Q-value expression.
    """
    if not isinstance(policy, StochasticPolicy):
        # Non-stochastic policy: plug its action straight into Q(s, a).
        action_var = policy.get_action_sym(obs_var)
        return self.get_qval_sym(obs_var, action_var, **kwargs)

    dist_info = policy.dist_info_sym(obs_var)
    policy_mean = dist_info['mean']
    policy_log_std = dist_info["log_std"]
    # Diagonal matrix whose entries are the per-dimension standard deviations.
    # NOTE(review): this passes std, not variance, to get_e_A_sym. Confirm
    # that get_e_A_sym expects std (an expected quadratic advantage would
    # normally use the covariance).
    policy_std_mat = tf.matrix_diag(tf.exp(policy_log_std))

    L_var, V_var, mu_var = self.get_output_sym(obs_var, **kwargs)
    L_mat_var = self.get_L_sym(L_var)
    P_var = self.get_P_sym(L_mat_var)
    A_var = self.get_e_A_sym(P_var, mu_var, policy_mean, policy_std_mat)
    return A_var + V_var
评论列表
文章目录