categorical.py source code

Project: pyro, Author: uber

def enumerate_support(self):
        """
        Returns the categorical distribution's support, as a tensor along the first dimension.

        Note that this returns support values of all the batched RVs in lock-step, rather
        than the full Cartesian product. To iterate over the Cartesian product, you must
        construct univariate Categoricals and use itertools.product() over all univariate
        variables (but this is very expensive).

        :param ps: Tensor where the last dimension denotes the event probabilities, *p_k*,
            which must sum to 1. The remaining dimensions are considered batch dimensions.
        :type ps: torch.autograd.Variable
        :param vs: Optional parameter enumerating the items in the support. Its values may be
            numeric or strings. It should have the same dimensions as ``ps``.
        :type vs: list or numpy.ndarray or torch.autograd.Variable
        :param one_hot: Denotes whether one-hot encoding is enabled. This is True by default.
            When True and no explicit ``vs`` is provided, the last dimension gives the
            one-hot encoded value from the support; when False, support values are integer indices.
        :type one_hot: boolean
        :return: Torch Variable or numpy array enumerating the support of the categorical distribution.
            Each item in the return value, when enumerated along the first dimension, yields a
            value from the distribution's support with the same dimensions as a value returned by
            ``sample``. If ``one_hot=True``, the last dimension is used for the one-hot encoding.
        :rtype: torch.autograd.Variable or numpy.ndarray
        """
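        # With K categories (event_shape == (K,)), index-valued support elements have shape
        # batch_shape + (1,); stacking the K elements along a new first dimension gives
        # event_shape + batch_shape + (1,).
        sample_shape = self.batch_shape() + (1,)
        support_samples_size = self.event_shape() + sample_shape
        vs = self.vs

        if vs is not None:
            # Explicit support values: transpose so the category axis comes first,
            # then reshape to support_samples_size.
            if isinstance(vs, np.ndarray):
                return vs.transpose().reshape(*support_samples_size)
            else:
                return torch.transpose(vs, 0, -1).contiguous().view(support_samples_size)
        if self.one_hot:
            # One-hot support: row k of the K x K identity, expanded to the shape of ps.
            return Variable(torch.stack([t.expand_as(self.ps) for t in torch_eye(*self.event_shape())]))
        else:
            # Integer-index support: each category index k expanded to sample_shape,
            # using a LongTensor on the same device as ps.
            LongTensor = torch.cuda.LongTensor if self.ps.is_cuda else torch.LongTensor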
            return Variable(
                torch.stack([LongTensor([t]).expand(sample_shape)
                             for t in torch.arange(0, *self.event_shape()).long()]))
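
The docstring's note on lock-step versus Cartesian-product enumeration can be made concrete with a small pure-Python sketch. It assumes a hypothetical batch of two variables with three categories each and uses nested lists in place of tensors; nothing below is Pyro API.

```python
import itertools

# Hypothetical batch of B = 2 categorical variables over K = 3 categories each
# (plain lists stand in for the ps tensor; this is not Pyro code).
ps = [[0.2, 0.3, 0.5],
      [0.6, 0.1, 0.3]]
batch_size, num_categories = len(ps), len(ps[0])

# Lock-step enumeration, as enumerate_support does it: the k-th support
# element assigns category k to every batch member at once, giving K values.
lock_step = [tuple([k] * batch_size) for k in range(num_categories)]

# Full Cartesian product over the univariate supports, giving K ** B values.
cartesian = list(itertools.product(*(range(len(row)) for row in ps)))

print(lock_step)       # [(0, 0), (1, 1), (2, 2)]
print(len(cartesian))  # 9
```

The lock-step list grows linearly in K, while the Cartesian product grows as K ** B, which is why the docstring calls the product very expensive.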
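
The three return branches of the method (explicit ``vs``, one-hot, integer index) produce different layouts. The sketch below mirrors them with nested lists, assuming a single batch dimension so that ps has shape (B, K); `enumerate_support_sketch` is a hypothetical helper for illustration, not part of Pyro.

```python
def enumerate_support_sketch(ps, vs=None, one_hot=True):
    """Hypothetical list-based mirror of the three enumerate_support branches.

    Assumes ps is B rows of K probabilities, i.e. batch_shape == (B,)
    and event_shape == (K,).
    """
    batch_size, num_categories = len(ps), len(ps[0])
    if vs is not None:
        # Explicit values: transpose vs from (B, K) to (K, B) and add a
        # trailing singleton axis, giving shape (K, B, 1).
        return [[[value] for value in column] for column in zip(*vs)]
    if one_hot:
        # Row k of the K x K identity, repeated for every batch member,
        # giving shape (K, B, K).
        return [[[1 if j == k else 0 for j in range(num_categories)]
                 for _ in range(batch_size)]
                for k in range(num_categories)]
    # Integer index k repeated per batch member: shape (K, B, 1).
    return [[[k] for _ in range(batch_size)] for k in range(num_categories)]


ps = [[0.2, 0.3, 0.5],
      [0.6, 0.1, 0.3]]
print(enumerate_support_sketch(ps, one_hot=False))
# [[[0], [0]], [[1], [1]], [[2], [2]]]
print(enumerate_support_sketch(ps, vs=[["a", "b", "c"], ["x", "y", "z"]]))
# [[['a'], ['x']], [['b'], ['y']], [['c'], ['z']]]
```

The (K, B, 1) layouts correspond to support_samples_size in the method, and the (K, B, K) layout corresponds to stacking identity rows expanded to the shape of ps.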