test_categorical.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def setUp(self):
        """Populate the fixtures shared by the categorical tests.

        Three groups of attributes are created:

        * a single probability vector ``ps`` and a two-row batch
          ``batch_ps`` (every row sums to 1), together with sample data and
          the expected per-category mean ``n * p`` and variance
          ``n * p * (1 - p)`` for a trial count of ``n = 1``;
        * a "discrete" variant that pairs per-row probabilities ``d_ps``
          with explicit numeric values ``d_vs`` and string values
          (``d_vs_arr`` as lists, ``d_vs_tup`` as tuples);
        * the expected enumerated supports, both one-hot and index encoded.
          The ``*_non_vec`` variants cover a single row; the others repeat
          each support point for the two-row batch.
        """
        num_trials = 1
        letters_lo, letters_hi = 'abc', 'def'
        one_hot_rows = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]

        # Single and batched probability vectors plus sample observations.
        self.ps = Variable(torch.Tensor([0.1, 0.6, 0.3]))
        self.batch_ps = Variable(torch.Tensor([[0.1, 0.6, 0.3], [0.2, 0.4, 0.4]]))
        self.n = Variable(torch.Tensor([num_trials]))
        self.test_data = Variable(torch.Tensor([0, 1, 0]))
        self.test_data_nhot = Variable(torch.Tensor([2]))

        # Closed-form moments, scaled by the trial count.
        self.analytic_mean = num_trials * self.ps
        self.analytic_var = num_trials * (self.ps * (1 - self.ps))

        # Discrete Distribution: per-row probabilities over explicit values.
        self.d_ps = Variable(torch.Tensor([[0.2, 0.3, 0.5], [0.1, 0.1, 0.8]]))
        self.d_vs = Variable(torch.Tensor([list(range(3)), list(range(3, 6))]))
        self.d_vs_arr = [list(letters_lo), list(letters_hi)]
        self.d_vs_tup = (tuple(letters_lo), tuple(letters_hi))
        self.d_test_data = Variable(torch.Tensor([[0], [5]]))
        self.d_v_test_data = [['a'], ['f']]

        self.n_samples = 50000

        # Expected supports: one-hot (float) and index (int64) encodings.
        self.support_one_hot_non_vec = torch.Tensor(one_hot_rows)
        self.support_one_hot = torch.Tensor([[row, row] for row in one_hot_rows])
        self.support_non_vec = torch.LongTensor([[k] for k in range(3)])
        self.support = torch.LongTensor([[[k], [k]] for k in range(3)])

        # Expected supports for the discrete variant (numeric, then string).
        self.discrete_support_non_vec = torch.Tensor([[k] for k in range(3)])
        self.discrete_support = torch.Tensor([[[k], [k + 3]] for k in range(3)])
        self.discrete_arr_support_non_vec = [[ch] for ch in letters_lo]
        self.discrete_arr_support = [
            [[lo], [hi]] for lo, hi in zip(letters_lo, letters_hi)
        ]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号