def sample(self):
    """
    Draw from the Dirichlet distribution parameterized by ``self.alpha``.

    A one-dimensional ``self.alpha`` yields a single sample; otherwise one
    sample is drawn for every row along the first axis. Sampling goes through
    scipy on a CPU numpy copy of the concentrations, so the draw is
    un-reparameterized.

    :return: the sample(s), wrapped in a ``Variable`` built with the tensor
        type of ``self.alpha.data``
    """
    concentration = self.alpha.data.cpu().numpy()

    def draw_one(params):
        # ``rvs`` returns a batch of draws; keep the first (only) one.
        return spr.dirichlet.rvs(params)[0]

    if self.alpha.dim() != 1:
        # Batched case: fill a buffer shaped like alpha, one row at a time.
        draws = np.empty_like(concentration)
        for row in range(concentration.shape[0]):
            draws[row, :] = draw_one(concentration[row, :])
    else:
        draws = draw_one(concentration)

    # Rebuild a tensor of the same class as alpha's underlying data.
    tensor_cls = type(self.alpha.data)
    return Variable(tensor_cls(draws))
# Comment list
# Table of contents