def get_random_state():
"""Gets the state of the random number generator for the current device.
If the state for the current device is not created yet, this function
creates a new one, initializes it, and stores it as the state for the
current device.
Returns:
RandomState: The state of the random number generator for the
device.
"""
dev = cuda.Device()
rs = _random_states.get(dev.id, None)
if rs is None:
seed = os.getenv('CUPY_SEED')
if seed is None:
seed = os.getenv('CHAINER_SEED')
if seed is not None:
seed = numpy.uint64(int(seed))
rs = RandomState(seed)
rs = _random_states.setdefault(dev.id, rs)
return rs
评论列表
文章目录