def _generate_normal(self, func, size, dtype, *args):
    """Fill a freshly allocated array by calling ``func`` on an even count.

    The curand functions below don't support odd size:

    * curand.generateNormal
    * curand.generateNormalDouble
    * curand.generateLogNormal
    * curand.generateLogNormalDouble

    When the requested number of elements is odd, a flat buffer holding
    one extra element is filled instead, and the surplus element is
    sliced off before the result is reshaped to the requested shape.

    Args:
        func: Callable invoked as
            ``func(self._generator, pointer, count, *args)``.
        size: Requested output shape; normalized with ``core.get_size``.
        dtype: Data type handed to ``cupy.empty`` for the output buffer.
        *args: Extra positional arguments forwarded unchanged to ``func``.

    Returns:
        The array filled by ``func``, with the normalized ``size`` as its
        shape.
    """
    shape = core.get_size(size)
    n_elems = six.moves.reduce(operator.mul, shape, 1)
    needs_padding = n_elems % 2 != 0

    if needs_padding:
        # Odd count: over-allocate by one so ``func`` sees an even size.
        buf = cupy.empty((n_elems + 1,), dtype=dtype)
    else:
        buf = cupy.empty(shape, dtype=dtype)

    func(self._generator, buf.data.ptr, buf.size, *args)

    if not needs_padding:
        return buf
    # Drop the padding element, then restore the caller's shape.
    return buf[:n_elems].reshape(shape)
# NumPy compatible functions
# Comment list
# Article table of contents