def test_n_samples_compatibility():
    """
    Check that MultinomialFromUniform still works with graphs built by the old interface.

    Loads the fixture ``multinomial_test_graph.pkl`` (stored in the same
    directory as this file). The fixture was created with the old interface
    as follows:

        RandomStreams = theano.sandbox.rng_mrg.MRG_RandomStreams
        th_rng = RandomStreams(12345)
        X = T.matrix('X')
        pvals = T.exp(X)
        pvals = pvals / pvals.sum(axis=1, keepdims=True)
        samples = th_rng.multinomial(pvals=pvals)
        pickle.dump([X, samples], open("multinomial_test_graph.pkl", "w"))

    The unpickled graph is compiled, then evaluated on a random 20x10 input.
    Every row of the result must sum to exactly 1.

    Raises SkipTest on Windows when unpickling fails with a spurious
    ImportError. Any other ImportError propagates unchanged.
    """
    folder = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(folder, "multinomial_test_graph.pkl"),
              "rb") as pkl_file:
        if PY3:
            # The recipe above opens the file in text mode ("w"), which
            # suggests the fixture was pickled under Python 2. latin1 lets
            # Python 3 decode its byte strings. TODO confirm fixture origin.
            u = CompatUnpickler(pkl_file, encoding="latin1")
        else:
            u = CompatUnpickler(pkl_file)
        try:
            X, samples = u.load()
        except ImportError:
            # Windows sometimes fail with nonsensical errors like:
            #   ImportError: No module named type
            #   ImportError: No module named copy_reg
            # when "type" and "copy_reg" are builtin modules.
            if sys.platform == 'win32':
                exc_value, exc_trace = sys.exc_info()[1:]
                # On Python 3, reraise() raises the *value* it is given and
                # ignores the type argument. Passing the ImportError itself
                # would therefore re-raise it instead of skipping, so build
                # a SkipTest instance and keep the original traceback.
                reraise(SkipTest, SkipTest(str(exc_value)), exc_trace)
            raise

    # The graph is fully unpickled, so the fixture file can be closed before
    # compiling and running it.
    f = theano.function([X], samples)
    res = f(numpy.random.randn(20, 10))
    assert numpy.all(res.sum(axis=1) == 1)
评论列表
文章目录