gputools.py source code

python

Project: slitSpectrographBlind | Author: aasensio
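import numpy as np
import pycuda.gpuarray as cua
from pycuda.compiler import SourceModule

# _generate_preproc() and projection_code (CUDA source of the "projection"
# kernel) are defined elsewhere in gputools.py; a CUDA context must exist.

def project_on_basis_gpu(fs_gpu, basis_gpu):
    """Project fs_gpu onto each of the basis_gpu.shape[0] basis elements."""
    basis_length = basis_gpu.shape[0]
    shape = np.array(fs_gpu.shape).astype(np.uint32)
    dtype = fs_gpu.dtype
    # 16x16 threads per block; enough grid rows to cover every basis element.
    block_size = (16, 16, 1)
    grid_size = (1, int(np.ceil(float(basis_length) / block_size[1])))

    # Output buffer: one weight per basis element.
    weights_gpu = cua.empty(basis_length, dtype=dtype)

    # Prepend preprocessor definitions (dtype/shape, threads per block)
    # to the kernel source and compile it.
    preproc = _generate_preproc(dtype, shape)
    preproc += '#define BLOCK_SIZE %d\n' % (block_size[0] * block_size[1])
    mod = SourceModule(preproc + projection_code, keep=True)

    projection_fun = mod.get_function("projection")
    projection_fun(weights_gpu.gpudata, fs_gpu.gpudata, basis_gpu.gpudata,
                   np.uint32(basis_length),
                   block=block_size, grid=grid_size)
    return weights_gpu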
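The CUDA source in `projection_code` is not part of this snippet, so what the `projection` kernel reduces is not shown. Assuming it computes each weight as the inner product of `fs` with one basis element (weights[k] = sum of fs * basis[k]), a small NumPy reference like this one could be used to sanity-check the GPU output on the host. It also repeats the launch arithmetic: for `basis_length = 100` the grid is (1, 7) and `BLOCK_SIZE` is 256. `project_on_basis_cpu` and `launch_config` are hypothetical helper names, not functions from gputools.py.

```python
import math

import numpy as np


def launch_config(basis_length, block_size=(16, 16, 1)):
    """Same grid/block arithmetic as project_on_basis_gpu."""
    # One grid column and enough rows that grid_rows * block_size[1] >= basis_length.
    grid_size = (1, int(math.ceil(float(basis_length) / block_size[1])))
    # Threads per block, the value written into the BLOCK_SIZE macro.
    threads_per_block = block_size[0] * block_size[1]
    return grid_size, threads_per_block


def project_on_basis_cpu(fs, basis):
    """Host reference, ASSUMING weights[k] = sum(fs * basis[k]).

    fs has shape S and basis has shape (basis_length,) + S.
    """
    fs = np.asarray(fs)
    basis = np.asarray(basis)
    if basis.shape[1:] != fs.shape:
        raise ValueError("each basis element must have the same shape as fs")
    # Contract all axes of fs with the trailing axes of basis -> shape (basis_length,).
    return np.tensordot(basis, fs, axes=fs.ndim).astype(fs.dtype, copy=False)


rng = np.random.default_rng(0)
fs = rng.standard_normal((8, 8)).astype(np.float32)
basis = rng.standard_normal((5, 8, 8)).astype(np.float32)
print(project_on_basis_cpu(fs, basis).shape)  # (5,)
print(launch_config(100))                     # ((1, 7), 256)
```

To compare against the GPU path, the weights can be copied to the host with `GPUArray.get()` and checked with `np.allclose` using a tolerance suited to float32 summation order, provided the assumption about the kernel holds.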