test_multinomial.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_n_samples_compatibility():
    """
    Check that MultinomialFromUniform still works with graphs built through
    the old interface (the test name suggests the change in question is the
    ``n_samples`` argument -- confirm against the op's history).

    The fixture ``multinomial_test_graph.pkl``, stored next to this file,
    was produced with the old interface roughly like this::

        RandomStreams = theano.sandbox.rng_mrg.MRG_RandomStreams
        th_rng = RandomStreams(12345)
        X = T.matrix('X')
        pvals = T.exp(X)
        pvals = pvals / pvals.sum(axis=1, keepdims=True)
        samples = th_rng.multinomial(pvals=pvals)
        pickle.dump([X, samples], open("multinomial_test_graph.pkl", "w"))

    The pickled graph is loaded and compiled, then fed a random 20x10
    matrix. Every row of the sampled output must sum to exactly 1.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    graph_path = os.path.join(here, "multinomial_test_graph.pkl")
    # On Python 3 the unpickler is given latin1 decoding; presumably the
    # fixture was written by Python 2 (see the recipe above).
    unpickler_kwargs = {"encoding": "latin1"} if PY3 else {}

    with open(graph_path, "rb") as pkl_file:
        unpickler = CompatUnpickler(pkl_file, **unpickler_kwargs)
        try:
            x_input, samples = unpickler.load()
        except ImportError:
            # Windows sometimes fails with nonsensical errors such as
            #   ImportError: No module named type
            #   ImportError: No module named copy_reg
            # even though "type" and "copy_reg" are builtin modules.
            # On other platforms the error is genuine, so propagate it;
            # on Windows turn it into a skip, keeping message and traceback.
            if sys.platform != 'win32':
                raise
            _, exc_value, exc_trace = sys.exc_info()
            reraise(SkipTest, exc_value, exc_trace)

        sample_fn = theano.function([x_input], samples)
        drawn = sample_fn(numpy.random.randn(20, 10))
        assert numpy.all(drawn.sum(axis=1) == 1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号