bernoulli.py source code

python

Project: pyro    Author: uber
import torch
from torch.autograd import Variable


# Excerpt: a method of pyro's Bernoulli distribution class. self.ps holds the
# success probabilities, so its size is the batch shape.
def enumerate_support(self):
    """
    Returns the Bernoulli distribution's support, as a tensor along the first dimension.

    Note that this returns support values of all the batched RVs in lock-step, rather
    than the full cartesian product. To iterate over the cartesian product, you must
    construct univariate Bernoullis and use itertools.product() over all univariate
    variables (may be expensive).

    :return: torch variable enumerating the support of the Bernoulli distribution.
        Each item in the return value, when enumerated along the first dimension, yields a
        value from the distribution's support with the same shape as a value returned by
        sample().
    :rtype: torch.autograd.Variable
    """
    # Build one tensor of all 0s and one of all 1s, each expanded to the shape of
    # self.ps, then stack them: the result has size (2,) + self.ps.size().
    return Variable(torch.stack([torch.Tensor([t]).expand_as(self.ps) for t in [0, 1]]))
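The docstring's distinction between lock-step enumeration and the full cartesian product can be illustrated with a small pure-Python sketch. It deliberately avoids pyro and torch: a batch of Bernoulli variables is modeled only by an integer batch size, and the helper names `lockstep_support` and `cartesian_support` are hypothetical, chosen for this illustration.

```python
import itertools
from typing import List, Tuple


def lockstep_support(batch_size: int) -> List[List[int]]:
    """Mimic enumerate_support(): one entry per support value (0, then 1),
    each entry shaped like the batch, so all batched variables share a value."""
    return [[t] * batch_size for t in (0, 1)]


def cartesian_support(batch_size: int) -> List[Tuple[int, ...]]:
    """Every joint assignment of batch_size independent univariate Bernoulli
    variables, via itertools.product over each variable's own support."""
    per_variable_supports = [(0, 1)] * batch_size
    return list(itertools.product(*per_variable_supports))


if __name__ == "__main__":
    print(lockstep_support(3))        # [[0, 0, 0], [1, 1, 1]]
    print(len(cartesian_support(3)))  # 8, i.e. 2 ** 3
```

For a batch of n variables the lock-step form always has 2 entries, while the cartesian product has 2 ** n, which is why the docstring warns that the latter may be expensive.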