def patch(x, ksize=3, stride=1, pad=0):
xp = cuda.get_array_module(x.data)
b, ch, h, w = x.data.shape
w = xp.identity(ch * ksize * ksize, dtype=np.float32).reshape((ch * ksize * ksize, ch, ksize, ksize))
return F.convolution_2d(x, W=w, stride=stride, pad=pad)
评论列表
文章目录