def test_specified_rng():
    """Check that initializer sampling is reproducible under a reseeded global RNG.

    For each initializer class, the global RNG is replaced (via ``set_rng``)
    with a freshly seeded ``RandomState`` before each of two draws of a sample
    of ``shape``; the two samples must be numerically close.

    The RNG returned by ``get_rng()`` before the test runs is restored in a
    ``finally`` clause. This covers failed assertions and exceptions raised
    by an initializer, so the seeded RNG never leaks into other tests.
    """
    from npdl.utils.random import get_rng
    from npdl.utils.random import set_rng
    from npdl.initializations import Normal
    from npdl.initializations import Uniform
    from npdl.initializations import GlorotNormal
    from npdl.initializations import GlorotUniform
    from numpy.random import RandomState
    from numpy import allclose

    shape = (10, 20)
    seed = 12345
    # Remember the RNG that was active before this test so it can be put back.
    rng = get_rng()
    try:
        for test_cls in [Normal, Uniform, GlorotNormal, GlorotUniform]:
            # Identical seeds before each draw should yield identical samples.
            set_rng(RandomState(seed))
            sample1 = test_cls().call(shape)
            set_rng(RandomState(seed))
            sample2 = test_cls().call(shape)
            assert allclose(sample1, sample2), \
                "random initialization was inconsistent " \
                "for {}".format(test_cls.__name__)
    finally:
        # Reset to the original RNG for other tests. This runs even if an
        # initializer raised or an assertion failed above.
        set_rng(rng)
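# 评论列表 ("comment list" — web-page residue from the source this was copied from)
# 文章目录 ("table of contents" — web-page residue from the source this was copied from)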