def make_node(self, input):
    """Build the ``Apply`` node for this GPU reduction.

    The output dtype, broadcastable pattern and the op (with its ``axis``)
    come from ``CAReduceDtype.make_node``. That call receives the input as
    passed in, *before* it is converted with ``as_gpuarray_variable`` on the
    context inferred from that same input.

    Parameters
    ----------
    input
        Variable to reduce. (The name shadows the builtin but is kept
        because it is part of the public signature.)

    Returns
    -------
    Apply
        A node whose op is ``res.op`` (the op returned by
        ``CAReduceDtype.make_node``). Its single input is the GPU-converted
        variable. Its single output is a fresh ``GpuArrayType`` variable
        with the dtype and broadcastable pattern of the CPU result, on the
        inferred context.

    Notes
    -----
    When ``res.op.axis`` is not None, ``res.op.redux`` is assigned in place.
    It becomes a list of booleans with one entry per input dimension, and an
    entry is ``True`` when that dimension index is in ``axis``. When ``axis``
    is None, ``redux`` is left untouched here.
    NOTE(review): ``res.op`` may be ``self`` itself, in which case this
    mutates the caller's op. Confirm against ``CAReduceDtype.make_node``.
    """
    ctx_name = infer_context_name(input)
    # Delegate to the CPU reduction op; its result supplies the output
    # dtype/broadcastable pattern and the op (and axis) used below.
    res = CAReduceDtype.make_node(self, input)
    input = as_gpuarray_variable(input, ctx_name)

    cpu_out = res.outputs[0]
    otype = GpuArrayType(dtype=cpu_out.dtype,
                         broadcastable=cpu_out.broadcastable,
                         context_name=ctx_name)

    axis = res.op.axis
    if axis is not None:
        # Since redux is just another way to describe what is in axis,
        # it doesn't need to be compared in __eq__ or __hash__.
        res.op.redux = [i in axis
                        for i in range(len(input.type.broadcastable))]

    return Apply(res.op, [input], [otype()])
# [scraped page residue] Comment list
# [scraped page residue] Table of contents