test_torch.py 文件源码

python
阅读 43 收藏 0 点赞 0 评论 0

项目:pytorch 作者: tylergenter 项目源码 文件源码
def test_multinomial(self):
        # with replacement
        n_row = 3
        for n_col in range(4, 5 + 1):
            prob_dist = torch.rand(n_row, n_col)
            prob_dist.select(1, n_col - 1).fill_(0)  # index n_col shouldn't be sampled
            n_sample = n_col
            sample_indices = torch.multinomial(prob_dist, n_sample, True)
            self.assertEqual(prob_dist.dim(), 2)
            self.assertEqual(sample_indices.size(1), n_sample)
            for index in product(range(n_row), range(n_sample)):
                self.assertNotEqual(sample_indices[index], n_col, "sampled an index with zero probability")

        # without replacement
        n_row = 3
        for n_col in range(4, 5 + 1):
            prob_dist = torch.rand(n_row, n_col)
            prob_dist.select(1, n_col - 1).fill_(0)  # index n_col shouldn't be sampled
            n_sample = 3
            sample_indices = torch.multinomial(prob_dist, n_sample, False)
            self.assertEqual(prob_dist.dim(), 2)
            self.assertEqual(sample_indices.size(1), n_sample)
            for i in range(n_row):
                row_samples = {}
                for j in range(n_sample):
                    sample_idx = sample_indices[i, j]
                    self.assertNotEqual(sample_idx, n_col - 1,
                                        "sampled an index with zero probability")
                    self.assertNotIn(sample_idx, row_samples, "sampled an index twice")
                    row_samples[sample_idx] = True

        # vector
        n_col = 4
        prob_dist = torch.rand(n_col)
        n_sample = n_col
        sample_indices = torch.multinomial(prob_dist, n_sample, True)
        s_dim = sample_indices.dim()
        self.assertEqual(sample_indices.dim(), 1, "wrong number of dimensions")
        self.assertEqual(prob_dist.dim(), 1, "wrong number of prob_dist dimensions")
        self.assertEqual(sample_indices.size(0), n_sample, "wrong number of samples")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号