curand.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:neurodriver 作者: neurokernel 项目源码 文件源码
def get_curand_int_func():
    code = """
#include "curand_kernel.h"
extern "C" {
__global__ void 
rand_setup(curandStateXORWOW_t* state, int size, unsigned long long seed)
{
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    int total_threads = blockDim.x * gridDim.x;

    for(int i = tid; i < size; i+=total_threads)
    {
        curand_init(seed, i, 0, &state[i]);
    }
}
}
    """
    mod = SourceModule(code, no_extern_c = True)
    func = mod.get_function("rand_setup")
    func.prepare('PiL')#[np.intp, np.int32, np.uint64])
    return func
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号