sequence_linear.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:nn_mask 作者: ZitengWang 项目源码 文件源码
def forward_gpu(self, inputs):
        x = inputs[0]
        W = inputs[1]
        # Prepare BLAS call
        handle = cuda.Device().cublas_handle
        k, m = W.shape
        n, l = x.shape[0] * x.shape[1], x.shape[2]
        lda = max(1, x.shape[-1])
        ldb = max(1, W.strides[0] // W.dtype.itemsize)
        ldc = max(1, m)
        Wx = cupy.empty((x.shape[0], x.shape[1], W.shape[1]),
                        dtype=numpy.float32)
        sgemm(handle, False, False, m, n, k, 1, W.data.ptr, ldb,
              x.data.ptr, lda, 0, Wx.data.ptr, ldc)
        if len(inputs) > 2:
            b = inputs[2]
            Wx += b
        return Wx,
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号