test_initialization.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:NumpyDL 作者: oujago 项目源码 文件源码
def test_specified_rng():
    """Check that initializers are reproducible when the global RNG is reseeded.

    For each initializer class, the global RNG is set to a fresh
    ``RandomState`` with the same seed before each of two ``call(shape)``
    invocations. The two samples must then match.

    The global RNG that was active before the test is always restored, even
    when an initializer raises or an assertion fails, so other tests are
    not affected by the fixed seed used here.
    """
    from npdl.utils.random import get_rng
    from npdl.utils.random import set_rng
    from npdl.initializations import Normal
    from npdl.initializations import Uniform
    from npdl.initializations import GlorotNormal
    from npdl.initializations import GlorotUniform

    from numpy.random import RandomState
    from numpy import allclose

    shape = (10, 20)
    seed = 12345
    original_rng = get_rng()

    try:
        for test_cls in [Normal, Uniform, GlorotNormal, GlorotUniform]:
            set_rng(RandomState(seed))
            sample1 = test_cls().call(shape)
            set_rng(RandomState(seed))
            sample2 = test_cls().call(shape)
            assert allclose(sample1, sample2), \
                "random initialization was inconsistent " \
                "for {}".format(test_cls.__name__)
    finally:
        # Restore the RNG on every exit path (including an exception raised
        # by an initializer) so the fixed seed does not leak into other tests.
        set_rng(original_rng)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号