def __call__(self, obs):
    """Evaluate the policy and Q-function on ``obs`` and combine them.

    Calls ``self.pi(obs)`` first and ``self.q(obs)`` second, multiplies the
    policy output's ``all_prob`` by the Q output's ``q_values``, and reduces
    that product with ``F.sum`` over ``axis=1``.

    Args:
        obs: Input forwarded unchanged to both ``self.pi`` and ``self.q``.

    Returns:
        A 3-tuple ``(pi_out, q_out, v)``: the result of ``self.pi(obs)``,
        the result of ``self.q(obs)``, and the ``axis=1`` sum of
        ``pi_out.all_prob * q_out.q_values``.
    """
    pi_out = self.pi(obs)
    q_out = self.q(obs)

    # NOTE(review): presumably axis 1 indexes actions, making this the
    # expected Q-value under the policy (a state value) -- confirm against
    # the shapes produced by self.pi / self.q.
    weighted_q = pi_out.all_prob * q_out.q_values
    return pi_out, q_out, F.sum(weighted_q, axis=1)