def get_curand_int_func():
    """Compile the ``rand_setup`` cuRAND kernel and return it prepared for launch.

    The kernel initialises ``size`` XORWOW generator states stored in the
    device array ``state``. Element ``i`` is initialised with
    ``curand_init(seed, i, 0, &state[i])``: every element shares ``seed``,
    uses ``i`` as its subsequence number, and starts at offset 0. A
    grid-stride loop (stride ``blockDim.x * gridDim.x``) is used, so any
    launch configuration covers all ``size`` elements.

    The returned function is prepared for the kernel signature
    ``(curandStateXORWOW_t* state, int size, unsigned long long seed)``,
    so it can be launched as, e.g.::

        func.prepared_call(grid, block, state_ptr, size, seed)

    Each call builds a new ``SourceModule``.
    NOTE(review): relies on ``SourceModule`` (``pycuda.compiler``) being
    imported at module level and on an active CUDA context; neither is
    visible in this chunk.

    Returns:
        The PyCUDA kernel function for ``rand_setup``, already prepared.
    """
    code = """
#include "curand_kernel.h"
extern "C" {
__global__ void
rand_setup(curandStateXORWOW_t* state, int size, unsigned long long seed)
{
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int total_threads = blockDim.x * gridDim.x;
for(int i = tid; i < size; i+=total_threads)
{
curand_init(seed, i, 0, &state[i]);
}
}
}
"""
    # no_extern_c=True: PyCUDA would otherwise wrap the entire source in
    # extern "C", including the C++ curand_kernel.h include. Instead only
    # the kernel is wrapped by hand, which keeps its name unmangled so
    # get_function("rand_setup") can find it.
    mod = SourceModule(code, no_extern_c=True)
    func = mod.get_function("rand_setup")
    # Argument layout as struct-module codes (native alignment):
    #   P -> curandStateXORWOW_t* state (device pointer)
    #   i -> int size
    #   Q -> unsigned long long seed (always 8 bytes)
    # 'L' (C unsigned long) is only 4 bytes on LLP64 platforms such as
    # Windows, which would mis-pack the 64-bit seed; 'Q' matches everywhere.
    func.prepare("PiQ")
    return func
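# NOTE: leftover web-page navigation text ("评论列表" = "comment list",
# "文章目录" = "table of contents"); not part of the program.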