def make_node(self, pvals, unis, n):
    """Validate the inputs and build the ``Apply`` node for this op.

    Parameters
    ----------
    pvals
        Variable with dtype ``'float32'``; must be 2-dimensional once
        converted to a GPU array variable.  (Presumably the per-row
        probabilities -- confirm against the op's documentation.)
    unis
        Variable with dtype ``'float32'``; must be 1-dimensional once
        converted to a GPU array variable.  (Presumably the uniform
        samples driving the draw -- confirm against the op's
        documentation.)
    n
        Value wrapped with ``as_scalar`` and passed as the third input
        of the node.

    Returns
    -------
    Apply
        Node whose inputs are ``[pvals, unis, as_scalar(n)]`` and whose
        single output is a ``GpuArrayType`` variable of dtype ``'int64'``
        living in the context inferred from ``pvals`` and ``unis``.

    Raises
    ------
    AssertionError
        If ``pvals`` or ``unis`` is not ``'float32'``, or if the
        requested output dtype is not ``'int64'``.
    NotImplementedError
        If ``pvals`` is not 2-d or ``unis`` is not 1-d.
    """
    # Explicit raises instead of ``assert`` so the checks survive
    # ``python -O``; AssertionError is kept so existing callers that
    # catch it keep working.
    if pvals.dtype != 'float32':
        raise AssertionError('pvals dtype should be float32', pvals.dtype)
    if unis.dtype != 'float32':
        raise AssertionError('unis dtype should be float32', unis.dtype)

    # Both inputs are moved to one common GPU context.
    ctx_name = infer_context_name(pvals, unis)
    pvals = as_gpuarray_variable(pvals, ctx_name)
    unis = as_gpuarray_variable(unis, ctx_name)

    if pvals.ndim != 2:
        raise NotImplementedError('pvals ndim should be 2', pvals.ndim)
    if unis.ndim != 1:
        raise NotImplementedError('unis ndim should be 1', unis.ndim)

    # 'auto' resolves to int64, which is also the only dtype accepted.
    odtype = 'int64' if self.odtype == 'auto' else self.odtype
    if odtype != 'int64':
        raise AssertionError(odtype)

    # The output's broadcast pattern is pvals' pattern with its two axes
    # swapped.  NOTE(review): presumably the op emits its result
    # transposed relative to pvals -- confirm against the kernel/caller.
    br = (pvals.broadcastable[1], pvals.broadcastable[0])
    out = GpuArrayType(broadcastable=br,
                       dtype=odtype,
                       context_name=ctx_name)()
    return Apply(self, [pvals, unis, as_scalar(n)], [out])
评论列表
文章目录