def tdma(a, b, c, d, workgrp_size=None):
    """Run the compiled ``tdma`` OpenCL kernel on four equally shaped arrays.

    The length of the last axis is handed to ``compile_tdma`` together with
    the OpenCL type name of the dtype; the kernel is then launched with a
    one-dimensional global work size equal to the product of all leading
    axes (1 when the inputs are one-dimensional).

    NOTE(review): the name suggests the tridiagonal matrix algorithm, with
    ``a``, ``b``, ``c`` as the diagonals and ``d`` as the right-hand side of
    one system per leading index -- confirm against the kernel source that
    ``compile_tdma`` builds.

    Args:
        a, b, c, d: Arrays of identical shape and dtype that
            ``bh.interop_pyopencl.get_buffer`` can convert to OpenCL buffers.
        workgrp_size: Local work-group size, forwarded unchanged to the
            kernel launch (``None`` by default).

    Returns:
        A freshly allocated Bohrium array with the shape and dtype of ``a``;
        its buffer is passed to the kernel as the last argument (presumably
        the output -- confirm against the kernel).

    Raises:
        AssertionError: If the inputs differ in shape or dtype.
        NotImplementedError: If the PyOpenCL interop is not available.
    """
    assert a.shape == b.shape == c.shape == d.shape, "a, b, c, d must have the same shape"
    assert a.dtype == b.dtype == c.dtype == d.dtype, "a, b, c, d must have the same dtype"

    # Check that PyOpenCL is installed and that the Bohrium runtime uses the OpenCL backend
    if not bh.interop_pyopencl.available():
        raise NotImplementedError("OpenCL not available")

    # Get the OpenCL context from Bohrium
    ctx = bh.interop_pyopencl.get_context()
    queue = cl.CommandQueue(ctx)

    ret = bh.empty(a.shape, dtype=a.dtype)
    a_buf, b_buf, c_buf, d_buf, ret_buf = map(
        bh.interop_pyopencl.get_buffer, (a, b, c, d, ret)
    )

    prg = compile_tdma(ret.shape[-1], bh.interop_pyopencl.type_np2opencl_str(a.dtype))

    # The explicit initial value 1 keeps one-dimensional inputs working:
    # ``ret.shape[:-1]`` is then empty and ``reduce`` without an initial value
    # would raise TypeError instead of yielding a single work item.
    global_size = functools.reduce(operator.mul, ret.shape[:-1], 1)

    prg.tdma(queue, [global_size], workgrp_size, a_buf, b_buf, c_buf, d_buf, ret_buf)
    return ret
评论列表
文章目录