def sample_tree(self, num_samples):
    """Draw ``num_samples`` samples split uniformly at random across the ensemble.

    The requested total is divided among the members of ``self._ensemble``
    with a single multinomial draw using equal probabilities. Each member is
    then asked for its share through its own ``sample_tree`` method, and the
    combined list is shuffled so samples from different members are
    interleaved.

    Args:
        num_samples: Total number of samples to return.

    Returns:
        A list of the samples returned by the ensemble members, shuffled.
        Its length is checked (by an assertion) to equal ``num_samples``.

    Raises:
        ValueError: If the ensemble is empty.
        AssertionError: If the members together return a number of samples
            other than ``num_samples`` (only when assertions are enabled).
    """
    size = len(self._ensemble)
    if size == 0:
        # Without this guard, 1/size yields an empty probability vector and
        # the multinomial split either fails opaquely or produces no counts.
        raise ValueError("cannot sample from an empty ensemble")

    # float64 (NumPy's default) keeps the equal probabilities as close to
    # summing to 1 as possible; multinomial validates that sum.
    pvals = np.full(size, 1.0 / size)
    sub_nums = np.random.multinomial(num_samples, pvals)

    samples = []
    for server, sub_num in zip(self._ensemble, sub_nums):
        # Hand each member a plain Python int rather than a NumPy scalar.
        samples.extend(server.sample_tree(int(sub_num)))

    # Members' samples arrive grouped by member; shuffle to interleave them.
    np.random.shuffle(samples)
    assert len(samples) == num_samples, (
        "expected %d samples, got %d" % (num_samples, len(samples))
    )
    return samples
评论列表
文章目录