def test_select_distinct(self):
    """
    Tests that multinomial_wo_replacement always selects distinct elements.

    For several sample sizes ``i`` (up to the full population size
    ``n_elements``), one row of random probabilities is built, ``i``
    samples are requested from it, and the result is checked to:

    * contain exactly ``i`` entries,
    * contain no repeated entry (the property named in the test), and
    * contain only values that are valid indices in ``[0, n_elements)``.
    """
    th_rng = RandomStreams(12345)

    # Symbolic inputs: a float matrix of probabilities and an integer
    # number of samples to draw.
    p = tensor.fmatrix()
    n = tensor.iscalar()
    m = th_rng.multinomial_wo_replacement(pvals=p, n=n)

    f = function([p, n], m, allow_input_downcast=True)

    n_elements = 1000
    all_indices = range(n_elements)
    # Fixed seed so the generated probability rows are reproducible.
    numpy.random.seed(12345)
    for i in [5, 10, 50, 100, 500, n_elements]:
        # Strictly positive weights (randint lower bound is 1), normalised
        # so the single row sums to 1.
        pvals = numpy.random.randint(1, 100, (1, n_elements)).astype(config.floatX)
        pvals /= pvals.sum(1)
        res = f(pvals, i)
        # Drop the singleton dimension(s) so ``res`` is a flat sequence.
        res = numpy.squeeze(res)
        assert len(res) == i
        unique_res = numpy.unique(res)
        # Distinctness: without this check a sampler that returned
        # duplicates would still pass the length and range assertions.
        assert len(unique_res) == i, res
        assert numpy.all(numpy.in1d(unique_res, all_indices)), res
评论列表
文章目录