def make_node(self, o, W, h, inputIdx, outputIdx):
    """Build the Apply node for this op from its five inputs.

    ``o``, ``W`` and ``h`` are converted to GPU array variables on a
    context name inferred from those three inputs; ``inputIdx`` and
    ``outputIdx`` are converted to tensor variables.

    After conversion the inputs must have 3, 4, 3, 2 and 2 dimensions
    respectively, and both index inputs must have a dtype contained in
    ``discrete_dtypes``. Each requirement is enforced with a bare
    ``assert``, so a violation surfaces as ``AssertionError``.

    Returns an ``Apply`` over the five converted inputs whose single
    output is a fresh variable created from the type of the converted
    ``o``.
    """
    context_name = infer_context_name(o, W, h)

    # Convert in the same left-to-right order as the checks below.
    o, W, h = (as_gpuarray_variable(v, context_name) for v in (o, W, h))
    inputIdx, outputIdx = map(as_tensor_variable, (inputIdx, outputIdx))

    node_inputs = [o, W, h, inputIdx, outputIdx]

    # Expected number of dimensions, aligned position-by-position
    # with node_inputs.
    expected_ranks = (3, 4, 3, 2, 2)
    for variable, rank in zip(node_inputs, expected_ranks):
        assert variable.ndim == rank

    # Both index inputs must use one of the discrete dtypes.
    for index_var in (inputIdx, outputIdx):
        assert index_var.type.dtype in discrete_dtypes

    return Apply(self, node_inputs, [o.type()])
评论列表
文章目录