def logprob(self, data):
    """Return the log-probability of ``data`` under an equal-weight ensemble.

    Every member of ``self._ensemble`` scores ``data`` through its own
    ``logprob`` method.  With ``K = len(self._ensemble)``, the member scores
    are combined as ``log((1 / K) * sum_k exp(logprob_k))``, i.e. the log of
    the average of the members' probabilities.

    Args:
        data: Object with a ``shape`` attribute; ``data.shape[0]`` is the
            expected length of the result.

    Returns:
        Array of shape ``(data.shape[0],)``.

    Raises:
        AssertionError: If the combined result does not have shape
            ``(data.shape[0],)`` (only when assertions are enabled).
    """
    # One row per ensemble member, stacked along a new leading axis.
    # NOTE(review): presumably each member returns shape (data.shape[0],)
    # -- confirm against the member implementations.
    member_scores = [member.logprob(data) for member in self._ensemble]
    stacked = np.stack(member_scores)
    # logsumexp over the member axis avoids under/overflow from exponentiating
    # log-probabilities; subtracting log(K) turns the sum into an average.
    combined = logsumexp(stacked, axis=0)
    combined = combined - np.log(len(self._ensemble))
    assert combined.shape == (data.shape[0], )
    return combined
评论列表
文章目录