def enumerate_support(self):
    """
    Return the Bernoulli distribution's support, stacked along a new first dimension.

    Note that this returns the support values of all the batched random variables in
    lock-step, rather than the full cartesian product. To iterate over the cartesian
    product, construct univariate Bernoullis and use ``itertools.product()`` over all
    univariate variables (this may be expensive).

    The support values are created with the same tensor type (dtype and device) as
    ``self.ps`` rather than the default tensor type. This lets them be combined with
    ``self.ps`` without type or device mismatches.

    :return: variable of shape ``(2,) + self.ps.size()``. Index 0 along the first
        dimension is filled with zeros and index 1 with ones. Each item along the first
        dimension is a value from the distribution's support with the same shape as
        would be returned by ``sample``.
    :rtype: torch.autograd.Variable
    """
    # Use the raw tensor behind ``self.ps``. ``clone()`` gives a fresh buffer with the
    # same dtype/device/shape, so ``fill_`` never mutates the parameter itself.
    # Unlike ``torch.Tensor([t]).expand_as(...)``, this also works for a 0-dim ``ps``.
    ps = self.ps.data
    return Variable(torch.stack([ps.clone().fill_(t) for t in (0, 1)]))
评论列表
文章目录