def _call(self, *args, **kwargs):
    """Invoke the fused function on ``args``.

    If every argument is a ``core.ndarray`` or a ``numpy.generic``, the call
    goes to a kernel built by ``_get_fusion``. That kernel is memoized in
    ``self._memo``, keyed on the tuple of argument dtypes.

    Otherwise at least one argument is a plain ``numpy.ndarray``. In that
    case ``self.func`` is applied directly, followed by ``self.reduce`` and
    ``self.post_map`` when a reduction is configured. A warning is emitted
    if ``core.ndarray`` arguments were mixed into such a call.

    Args:
        *args: Input values. Each must be a ``core.ndarray``,
            ``numpy.ndarray`` or ``numpy.generic``.
        **kwargs: Only ``axis`` is read. It is forwarded to the reduction
            when ``self.reduce`` is not None. Any other keys are ignored.

    Returns:
        The result of the fused kernel, or of the Python fallback path.

    Raises:
        TypeError: If no arguments are given, or if any argument has an
            unsupported type.
    """
    axis = kwargs.get('axis')
    if not args:
        # TypeError (a subclass of the plain Exception raised previously, so
        # existing ``except Exception`` handlers still work) is the
        # conventional error for a call with missing arguments.
        raise TypeError('number of arguments must be more than 0')

    # ``builtins.any`` / ``builtins.all`` are spelled out explicitly,
    # presumably because this module defines its own ``any`` / ``all``.
    # NOTE(review): verify before replacing them with the bare builtins.
    accepted_types = (core.ndarray, numpy.ndarray, numpy.generic)
    if builtins.any(not isinstance(arg, accepted_types) for arg in args):
        raise TypeError('Invalid argument type for \'{}\': ({})'.format(
            self.name,
            ', '.join(repr(type(arg)) for arg in args)))

    fusable_types = (core.ndarray, numpy.generic)
    if builtins.all(isinstance(arg, fusable_types) for arg in args):
        # Fused path: one compiled kernel per distinct dtype signature.
        dtypes = [arg.dtype for arg in args]
        key = tuple(dtypes)
        if key not in self._memo:
            if self.input_num is not None:
                nin = self.input_num
            else:
                nin = len(args)
            self._memo[key] = _get_fusion(
                self.func, nin, self.reduce, self.post_map,
                self.identity, dtypes, self.name)
        kernel = self._memo[key]
        if self.reduce is None:
            return kernel(*args)
        return kernel(*args, axis=axis)

    # Fallback path: at least one plain numpy.ndarray prevents fusion.
    if builtins.any(type(arg) is core.ndarray for arg in args):
        # Use the same ', ' separator as the TypeError message above.
        arg_types = ', '.join(repr(type(arg)) for arg in args)
        message = "Can't fuse \n %s(%s)" % (self.name, arg_types)
        warnings.warn(message)
    if self.reduce is None:
        return self.func(*args)
    if axis is None:
        # Call ``reduce`` without an ``axis`` keyword so its own default
        # applies.
        return self.post_map(self.reduce(self.func(*args)))
    return self.post_map(self.reduce(self.func(*args), axis=axis))
评论列表
文章目录