def __call__(self, x):
    """Run ``self.conv`` on ``x`` and rearrange its channels into space.

    This is a sub-pixel ("pixel shuffle" / depth-to-space) upsampling step.
    The convolution output, indexed below as (batch, channels, height, width),
    is reshaped so that every group of ``r ** 2`` channels becomes an
    ``r x r`` spatial block, where ``r`` is ``self.r``::

        (B, C * r**2, H, W)  ->  (B, C, H * r, W * r)

    Output pixel ``(h * r + i, w * r + j)`` of channel ``c`` is taken from
    conv-output channel ``(i * r + j) * C + c`` at position ``(h, w)``. In
    other words, the ``r x r`` sub-position is the slow-varying part of the
    channel index.

    NOTE(review): this ordering appears to match
    ``chainer.functions.depth2space``; confirm before replacing the manual
    reshape/transpose with that call.

    Args:
        x: Input forwarded unchanged to ``self.conv``.

    Returns:
        Array of shape ``(B, C, H * r, W * r)``.

    Raises:
        ValueError: If the convolution output's channel count is not a
            multiple of ``r ** 2``.
    """
    r = self.r
    out = self.conv(x)
    batchsize = out.shape[0]
    in_channels = out.shape[1]
    in_height = out.shape[2]
    in_width = out.shape[3]

    # Use exact integer division and fail early with a clear message. A float
    # divide followed by int() would truncate non-multiples and leave the
    # reshape below to fail with an opaque size-mismatch error.
    block_size = r ** 2
    if in_channels % block_size != 0:
        raise ValueError(
            "pixel shuffle needs the conv output channel count ({}) to be a "
            "multiple of r ** 2 ({})".format(in_channels, block_size))
    out_channels = in_channels // block_size
    out_height = in_height * r
    out_width = in_width * r

    # Split the channel axis into (row offset i, column offset j, channel c).
    out = F.reshape(out, (batchsize, r, r, out_channels, in_height, in_width))
    # Interleave the offsets with the spatial axes: (B, C, H, i, W, j).
    out = F.transpose(out, (0, 3, 4, 1, 5, 2))
    # Merge (H, i) -> H * r and (W, j) -> W * r.
    out = F.reshape(out, (batchsize, out_channels, out_height, out_width))
    return out
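# 评论列表 ("Comment list") -- web-page navigation residue, not code
# 文章目录 ("Table of contents") -- web-page navigation residue, not code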