parray.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:neurodriver 作者: neurokernel 项目源码 文件源码
def mul(self, other):
        """
        Multiply ``self`` by ``other`` in place and return ``self``.

        Parameters
        ----------
        other : PitchArray or scalar
            Either a PitchArray with exactly the same ``shape`` as ``self``,
            or a Python/NumPy int, float or complex scalar.

        Returns
        -------
        PitchArray
            ``self``, after the multiplication kernel has been launched
            (or unchanged when there is nothing to do).

        Raises
        ------
        ValueError
            If ``other`` is a PitchArray whose shape differs from ``self``.
        TypeError
            If ``other`` is neither a PitchArray nor a supported scalar.
        """
        if isinstance(other, PitchArray):
            if self.shape != other.shape:
                raise ValueError("array dimension misaligned")
            # The returned dtype is not needed here; the call is kept,
            # presumably because it rejects dtype combinations whose result
            # cannot be stored back into self -- TODO confirm in helper.
            _get_inplace_dtype(self, other)
            if not self.size:
                # Empty array: no kernel to launch.
                return self
            if self.M == 1:
                # Single-row layout: non-pitched kernel over self.size items.
                func = pu.get_mularray_function(
                    self.dtype, other.dtype, self.dtype, pitch=False)
                func.prepared_call(
                    self._grid, self._block, self.gpudata,
                    self.gpudata, other.gpudata, self.size)
            else:
                # Multi-row layout: pitched kernel takes the M x N extents
                # plus each operand's ``ld`` (presumably its row pitch).
                func = pu.get_mularray_function(
                    self.dtype, other.dtype, self.dtype, pitch=True)
                func.prepared_call(
                    self._grid, self._block, self.M, self.N,
                    self.gpudata, self.ld,
                    self.gpudata, self.ld,
                    other.gpudata, other.ld)
            return self

        if isinstance(other, (float, int, complex, np.integer,
                              np.floating, np.complexfloating)):
            # Result unused; kept for its presumed dtype validation of the
            # scalar against self -- TODO confirm in helper.
            _get_inplace_dtype_with_scalar(other, self)
            if other == 1 or not self.size:
                # Multiplying by one is a no-op, and an empty array has
                # nothing to launch.
                return self
            if self.M == 1:
                func = pu.get_mulscalar_function(
                    self.dtype, self.dtype, pitch=False)
                func.prepared_call(
                    self._grid, self._block, self.gpudata,
                    self.gpudata, other, self.size)
            else:
                func = pu.get_mulscalar_function(
                    self.dtype, self.dtype, pitch=True)
                func.prepared_call(
                    self._grid, self._block, self.M, self.N,
                    self.gpudata, self.ld, self.gpudata,
                    self.ld, other)
            return self

        # Adjacent literals are concatenated, so the trailing space matters.
        raise TypeError("type of object to be multiplied "
                        "is not supported")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号