def sample(self, n_samples=1):
    """Draw category indices from ``self.mass_function`` with replacement.

    The last dimension of the mass function indexes categories. Weights are
    passed straight to ``torch.multinomial``, which requires them to be
    non-negative and finite with a non-zero sum per distribution; they do
    not need to be normalized.

    Args:
        n_samples: Number of draws per distribution (default 1).

    Returns:
        A LongTensor of category indices, with the sample dimension first:
        ``(n_samples,)`` for a 1-D mass function, and
        ``(n_samples, *batch_shape)`` for a mass function of shape
        ``(*batch_shape, n_categories)``.
    """
    # .data is kept (rather than .detach()) to stay compatible with the
    # PyTorch version this file targets; sampling needs no autograd history.
    mass_function = self.mass_function.data
    if mass_function.ndimension() <= 2:
        # torch.multinomial handles 1-D and 2-D input natively. It returns
        # (n_samples,) or (n_rows, n_samples).
        res = torch.multinomial(mass_function, n_samples, replacement=True)
        # Sample dimension is first
        if res.ndimension() == 2:
            res = res.t()
        return res
    # More than one batch dimension: torch.multinomial rejects >2-D input.
    # Collapse the batch dims into one row axis, sample, then restore them
    # behind the leading sample dimension.
    batch_shape = tuple(mass_function.shape[:-1])
    flat = mass_function.reshape(-1, mass_function.shape[-1])
    res = torch.multinomial(flat, n_samples, replacement=True)  # (rows, n)
    return res.t().reshape((n_samples,) + batch_shape)
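# Scraped web-page residue, kept as a comment so the module still imports:
#   评论列表 ("comment list")
#   文章目录 ("article table of contents")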