def enumerate_support(self):
"""
Returns the categorical distribution's support, as a tensor along the first dimension.
Note that this returns support values of all the batched RVs in lock-step, rather
than the full cartesian product. To iterate over the cartesian product, you must
construct univariate Categoricals and use itertools.product() over all univariate
variables (but this is very expensive).
:param ps: Tensor where the last dimension denotes the event probabilities, *p_k*,
which must sum to 1. The remaining dimensions are considered batch dimensions.
:type ps: torch.autograd.Variable
:param vs: Optional parameter, enumerating the items in the support. This could either
have a numeric or string type. This should have the same dimension as ``ps``.
:type vs: list or numpy.ndarray or torch.autograd.Variable
:param one_hot: Denotes whether one hot encoding is enabled. This is True by default.
When set to false, and no explicit `vs` is provided, the last dimension gives
the one-hot encoded value from the support.
:type one_hot: boolean
:return: Torch variable or numpy array enumerating the support of the categorical distribution.
Each item in the return value, when enumerated along the first dimensions, yields a
value from the distribution's support which has the same dimension as would be returned by
sample. If ``one_hot=True``, the last dimension is used for the one-hot encoding.
:rtype: torch.autograd.Variable or numpy.ndarray.
"""
    sample_shape = self.batch_shape() + (1,)
    support_samples_size = self.event_shape() + sample_shape
    vs = self.vs
    if vs is not None:
        # An explicit support was provided: reshape it so the support values run
        # along the first dimension and the batch dimensions follow.
        if isinstance(vs, np.ndarray):
            return vs.transpose().reshape(*support_samples_size)
        else:
            return torch.transpose(vs, 0, -1).contiguous().view(support_samples_size)
    if self.one_hot:
        # One row of the identity matrix per support value, expanded to ps's shape,
        # so each stacked entry is a batch of identical one-hot vectors.
        return Variable(torch.stack([t.expand_as(self.ps) for t in torch_eye(*self.event_shape())]))
    else:
        # Integer-encoded support: one index per support value, expanded across the batch.
        LongTensor = torch.cuda.LongTensor if self.ps.is_cuda else torch.LongTensor
        return Variable(
            torch.stack([LongTensor([t]).expand(sample_shape)
                         for t in torch.arange(0, *self.event_shape()).long()]))
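
As a usage sketch (not verbatim library code): the ``Categorical(ps=..., one_hot=...)`` construction below is illustrative, assuming the surrounding class exposes the parameters described in the docstring, with ``batch_shape()`` returning the batch dimensions of ``ps`` and ``event_shape()`` the number of categories.

import torch
from torch.autograd import Variable

# Batch of 2 categorical RVs, each over 3 categories.
ps = Variable(torch.Tensor([[0.2, 0.3, 0.5],
                            [0.1, 0.1, 0.8]]))

d = Categorical(ps=ps, one_hot=True)        # hypothetical construction, see above
support = d.enumerate_support()
print(support.size())                       # torch.Size([3, 2, 3]): 3 support values,
                                            # batch of 2, one-hot dimension of 3

d_idx = Categorical(ps=ps, one_hot=False)   # hypothetical construction
print(d_idx.enumerate_support().size())     # torch.Size([3, 2, 1]): integer index per batch element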
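
The docstring's note about lock-step enumeration versus the full cartesian product can be illustrated with plain Python (a standalone sketch with made-up sizes, independent of the library):

import itertools

# Sketch with made-up sizes: a batch of 2 RVs, each with 3 support values.
num_values, batch = 3, 2

# Lock-step enumeration (what enumerate_support returns): one entry per support
# value, shared across the batch -- 3 entries in total.
lock_step = [[v] * batch for v in range(num_values)]

# Full cartesian product over univariate supports (what the docstring warns is
# expensive): num_values ** batch = 9 combinations.
cartesian = list(itertools.product(range(num_values), repeat=batch))

print(lock_step)        # [[0, 0], [1, 1], [2, 2]]
print(len(cartesian))   # 9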