def forward_gpu(self, inputs):
    """Run the filter transformation on the GPU.

    Unpacks the single input array, allocates an uninitialized output
    array of the same dtype, and hands both to ``index_group_func_kernel``
    together with the index tables ``self.T``, ``self.U`` and ``self.V``.

    Args:
        inputs: A 1-element sequence holding the filter array. It must be
            5-dimensional; the unpacking below raises otherwise.
            Presumably laid out as (out_channels, in_channels,
            in_transforms, height, width). TODO confirm against caller.

    Returns:
        A 1-tuple containing the output array. Its shape is
        ``(och, nto, ich, nti, ny, nx)``, where ``nto`` and ``nti`` are the
        first two dimensions of ``self.T``.
    """
    (weights,) = inputs
    array_module = cuda.get_array_module(weights)

    # The third axis of the input is not used to size the output. Its
    # extent is also not checked against n_in_transforms here.
    # NOTE(review): confirm the kernel expects them to match.
    out_channels, in_channels, _, height, width = weights.shape
    n_out_transforms, n_in_transforms = self.T.shape[:2]

    output_shape = (
        out_channels,
        n_out_transforms,
        in_channels,
        n_in_transforms,
        height,
        width,
    )
    # The buffer is allocated with `empty`, not `zeros`. This assumes the
    # kernel writes every element. TODO confirm.
    transformed = array_module.empty(output_shape, dtype=weights.dtype)

    index_group_func_kernel(
        input=weights, T=self.T, U=self.U, V=self.V, output=transformed
    )
    return (transformed,)
评论列表
文章目录