categorical.py source code

python
Project: pyro  Author: uber
    def sample(self):
        """
        Returns a sample which has the same shape as `ps` (or `vs`), except
        that if ``one_hot=True`` (and no `vs` is specified), the last dimension
        will have the same size as the number of events. The type of the sample
        is `numpy.ndarray` if `vs` is a list or a numpy array, else a tensor
        is returned.

        :return: sample from the Categorical distribution
        :rtype: numpy.ndarray or torch.LongTensor
        """
        sample = torch_multinomial(self.ps.data, 1, replacement=True).expand(*self.shape())
        sample_one_hot = torch_zeros_like(self.ps.data).scatter_(-1, sample, 1)

        if self.vs is not None:
            if isinstance(self.vs, np.ndarray):
                sample_bool_index = sample_one_hot.cpu().numpy().astype(bool)
                return self.vs[sample_bool_index].reshape(*self.shape())
            else:
                return self.vs.masked_select(sample_one_hot.byte())
        if self.one_hot:
            return Variable(sample_one_hot)
        return Variable(sample)
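
The three return modes of the method above (value lookup through `vs`, one-hot vector, or plain index) can be sketched without tensors. This is only a minimal illustration in standard-library Python, not Pyro's API: the function name `sample_categorical`, its plain-list arguments, the `rng` parameter, and the default `one_hot=True` are all assumptions made for the example, and it draws a single sample rather than a batch.

```python
import random


def sample_categorical(ps, vs=None, one_hot=True, rng=random):
    """Draw one sample from a categorical distribution (illustrative only).

    ps: list of non-negative weights, one per event (stand-in for `self.ps`)
    vs: optional list of values, one per event (stand-in for `self.vs`)
    rng: the `random` module or a `random.Random` instance (hypothetical knob)
    """
    # Draw a single event index with probability proportional to its weight,
    # playing the role of the multinomial draw in the original method.
    index = rng.choices(range(len(ps)), weights=ps, k=1)[0]

    # Build the one-hot encoding of the drawn index, analogous to
    # scattering a 1 into a zero tensor along the last dimension.
    encoded = [1 if i == index else 0 for i in range(len(ps))]

    if vs is not None:
        # Use the one-hot vector as a mask to pick the matching value.
        return [v for v, flag in zip(vs, encoded) if flag][0]
    if one_hot:
        return encoded
    return index
```

As in the original, a supplied `vs` takes precedence over `one_hot`. For example, `sample_categorical([0.2, 0.8], vs=["heads", "tails"])` returns one of the two strings, while `sample_categorical([0.2, 0.8], one_hot=False)` returns an integer index.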