dirichlet.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def sample(self):
        """
        Draws either a single sample (if alpha.dim() == 1), or one sample per param (if alpha.dim() == 2).

        (Un-reparameterized).

        :param torch.autograd.Variable alpha:
        """
        alpha_np = self.alpha.data.cpu().numpy()
        if self.alpha.dim() == 1:
            x_np = spr.dirichlet.rvs(alpha_np)[0]
        else:
            x_np = np.empty_like(alpha_np)
            for i in range(alpha_np.shape[0]):
                x_np[i, :] = spr.dirichlet.rvs(alpha_np[i, :])[0]
        x = Variable(type(self.alpha.data)(x_np))
        return x
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号