def sample(self, dist_info):
    """Draw samples using the distribution parameters in ``dist_info``.

    Args:
        dist_info: Mapping that must contain the key ``"prob"``. Its value
            is passed unchanged to ``self._f_sample``.
            NOTE(review): presumably the per-category probabilities; the
            expected shape is not visible here, so confirm against callers.

    Returns:
        Whatever ``self._f_sample`` returns for ``dist_info["prob"]``.

    Raises:
        KeyError: If ``dist_info`` has no ``"prob"`` entry.
    """
    # The previous version left an ``ipdb.set_trace()`` breakpoint here. That
    # halted every call, required the third-party ``ipdb`` package, and
    # discarded the computed samples. Return them to the caller instead.
    samples = self._f_sample(dist_info["prob"])
    return samples