def forward_gpu(self, inputs):
    """Compute ``x @ W`` (plus optional bias) with one cuBLAS SGEMM call.

    The first two axes of ``x`` are flattened, so a single matrix product
    covers every position ``x[i, j, :]``.

    Args:
        inputs: ``(x, W)`` or ``(x, W, b)``. ``x`` is a 3-D float32 array
            whose last axis has length ``W.shape[0]``; ``W`` is a 2-D
            float32 array. ``b`` is added to the result in place with
            broadcasting (presumably shape ``(W.shape[1],)`` -- confirm
            against the caller).

    Returns:
        A one-element tuple holding a float32 array of shape
        ``(x.shape[0], x.shape[1], W.shape[1])``.

    Raises:
        TypeError: if ``x`` or ``W`` is not float32; SGEMM reinterprets
            the raw device buffers as float32.
        ValueError: if ``x`` is not 3-D, ``W`` is not 2-D, or
            ``x.shape[2] != W.shape[0]``.
    """
    x = inputs[0]
    W = inputs[1]
    if x.dtype != numpy.float32 or W.dtype != numpy.float32:
        raise TypeError(
            'forward_gpu requires float32 x and W, got {} and {}'.format(
                x.dtype, W.dtype))
    if x.ndim != 3 or W.ndim != 2 or x.shape[2] != W.shape[0]:
        raise ValueError(
            'incompatible shapes for x @ W: {} and {}'.format(
                x.shape, W.shape))

    # SGEMM only receives raw pointers plus one leading dimension per
    # operand, which can describe row-major data only when the last axis
    # has unit stride and rows are evenly spaced. Force C-contiguous
    # layout (no copy is made when the array already is).
    x = cupy.ascontiguousarray(x)
    W = cupy.ascontiguousarray(W)

    k, m = W.shape                  # contracted dim, output dim
    n = x.shape[0] * x.shape[1]     # number of flattened rows of x

    # cuBLAS is column-major. The row-major (k, m) W reads as the
    # column-major (m, k) matrix W^T, and the row-major (n, k) x reads as
    # the column-major (k, n) matrix x^T. Computing W^T @ x^T = (x @ W)^T
    # in column-major (m, n) form lays out exactly the row-major (n, m)
    # result, so no transpose flags are needed.
    ld_w = max(1, m)
    ld_x = max(1, k)
    ld_out = max(1, m)

    Wx = cupy.empty((x.shape[0], x.shape[1], m), dtype=numpy.float32)
    handle = cuda.Device().cublas_handle
    sgemm(handle, False, False, m, n, k, 1, W.data.ptr, ld_w,
          x.data.ptr, ld_x, 0, Wx.data.ptr, ld_out)
    if len(inputs) > 2:
        b = inputs[2]
        Wx += b
    return Wx,
评论列表
文章目录