fusion.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:cupy 作者: cupy 项目源码 文件源码
def _call(self, *args, **kwargs):
        """Apply the wrapped function to ``args``, fusing it when possible.

        When every argument is a ``core.ndarray`` or a ``numpy.generic``
        scalar, a fused callable is obtained from ``_get_fusion`` (cached in
        ``self._memo`` under the tuple of argument dtypes) and invoked.
        Otherwise ``self.func`` is called directly and, when ``self.reduce``
        is set, its result is passed through ``self.reduce`` and then
        ``self.post_map``.

        Args:
            *args: One or more ``core.ndarray``, ``numpy.ndarray`` or
                ``numpy.generic`` values.
            **kwargs: Only ``axis`` is read; it is forwarded to the reduction
                when ``self.reduce`` is not ``None``. Other keys are ignored.

        Returns:
            Whatever the fused callable, or the unfused
            ``func`` / ``reduce`` / ``post_map`` chain, returns.

        Raises:
            Exception: If no positional argument is given.
            TypeError: If any argument is not one of the accepted types.
        """
        axis = kwargs.get('axis')
        if not args:
            raise Exception('number of arguments must be more than 0')

        accepted_types = (core.ndarray, numpy.ndarray, numpy.generic)
        if not builtins.all(isinstance(arg, accepted_types) for arg in args):
            raise TypeError('Invalid argument type for \'{}\': ({})'.format(
                self.name,
                ', '.join(repr(type(arg)) for arg in args)))

        # Only these argument kinds take the fused path; a plain
        # numpy.ndarray among the arguments forces the fallback below.
        fusable_types = (core.ndarray, numpy.generic)
        if builtins.all(isinstance(arg, fusable_types) for arg in args):
            dtypes = [arg.dtype for arg in args]
            key = tuple(dtypes)
            if key not in self._memo:
                # One fused callable is built per distinct dtype signature.
                if self.input_num is not None:
                    nin = self.input_num
                else:
                    nin = len(args)
                self._memo[key] = _get_fusion(
                    self.func, nin, self.reduce, self.post_map,
                    self.identity, dtypes, self.name)
            fused = self._memo[key]
            if self.reduce is None:
                return fused(*args)
            return fused(*args, axis=axis)

        # Fallback path: warn only when an argument is exactly a
        # core.ndarray (subclasses do not trigger the warning).
        if builtins.any(type(arg) is core.ndarray for arg in args):
            # Use the same ', ' separator as the TypeError above so the
            # types read as an argument list rather than one dotted name.
            type_names = ', '.join(repr(type(arg)) for arg in args)
            message = "Can't fuse \n %s(%s)" % (self.name, type_names)
            warnings.warn(message)
        if self.reduce is None:
            return self.func(*args)
        if axis is None:
            return self.post_map(self.reduce(self.func(*args)))
        return self.post_map(self.reduce(self.func(*args), axis=axis))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号