test_multinomial.py source code


Project: Theano-Deep-learning  Author: GeekLiB
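```python
import unittest

import numpy
from theano import config, function, tensor
# Assumption: RandomStreams is Theano's MRG_RandomStreams, whose
# multinomial_wo_replacement method this test exercises.
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams


# Enclosing TestCase (class name assumed; the snippet shows only the method).
class TestMultinomialWoReplacement(unittest.TestCase):

    def test_select_distinct(self):
        """
        Tests that multinomial_wo_replacement always selects distinct elements
        """
        th_rng = RandomStreams(12345)

        p = tensor.fmatrix()
        n = tensor.iscalar()
        # Symbolic graph: draw n indices per row of p, without replacement.
        m = th_rng.multinomial_wo_replacement(pvals=p, n=n)

        f = function([p, n], m, allow_input_downcast=True)

        n_elements = 1000
        all_indices = numpy.arange(n_elements)
        numpy.random.seed(12345)
        for i in [5, 10, 50, 100, 500, n_elements]:
            # One row of random positive weights, normalized to sum to 1.
            pvals = numpy.random.randint(1, 100, (1, n_elements)).astype(config.floatX)
            pvals /= pvals.sum(axis=1, keepdims=True)
            res = numpy.squeeze(f(pvals, i))
            assert len(res) == i
            # Every sampled index is a valid category index...
            assert numpy.all(numpy.in1d(numpy.unique(res), all_indices)), res
            # ...and no index repeats, which is what the docstring promises.
            assert len(numpy.unique(res)) == i, res
```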
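The property under test (every index drawn by sampling without replacement is distinct) can also be illustrated without Theano. Below is a minimal NumPy sketch, assuming a sequential scheme: draw one index from the categorical distribution, set its probability to zero, renormalize, and repeat. This is one common way to sample without replacement and is not claimed to be how Theano's `multinomial_wo_replacement` is implemented; `sample_without_replacement` is a hypothetical helper name.

```python
import numpy as np


def sample_without_replacement(pvals, n, rng):
    """Return n distinct indices drawn according to the weights in pvals.

    Hypothetical helper: sequential scheme that draws one index from the
    current distribution, zeroes its probability, renormalizes, and repeats.
    """
    p = np.array(pvals, dtype=np.float64)  # private copy; caller's array is untouched
    if n > np.count_nonzero(p):
        raise ValueError("n exceeds the number of categories with nonzero probability")
    chosen = np.empty(n, dtype=np.int64)
    for k in range(n):
        p /= p.sum()                   # renormalize the remaining mass to 1
        idx = rng.choice(p.size, p=p)  # one categorical draw
        chosen[k] = idx
        p[idx] = 0.0                   # a drawn index can never be drawn again
    return chosen


if __name__ == "__main__":
    # Mirrors the loop in test_select_distinct, with NumPy in place of Theano.
    rng = np.random.default_rng(12345)
    n_elements = 1000
    for n in [5, 10, 50, 100, 500, n_elements]:
        pvals = rng.integers(1, 100, n_elements).astype(np.float64)
        pvals /= pvals.sum()
        res = sample_without_replacement(pvals, n, rng)
        assert len(res) == n
        assert np.all((res >= 0) & (res < n_elements))
        assert len(np.unique(res)) == n  # no repeated index
    print("all draws distinct")
```

In practice, NumPy's `Generator.choice(n_elements, size=n, replace=False, p=pvals)` performs weighted sampling without replacement in a single call; the explicit loop here only makes the zero-and-renormalize step visible.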