def __call__(self, x):
    # Apply a mask to the filters (optional)
    if self.filter_mask is not None:
        w, m = F.broadcast(self.W, Variable(self.filter_mask))
        w = w * m
        # w = self.W * Variable(self.filter_mask)
    else:
        w = self.W

    # Transform the filters
    # w.shape  == (out_channels, in_channels, input_stabilizer_size, ksize, ksize)
    # tw.shape == (out_channels, output_stabilizer_size, in_channels, input_stabilizer_size, ksize, ksize)
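    # (Conceptually, TransformGFilter gathers filter values through the
    # precomputed index array self.inds, producing one transformed copy of
    # each filter per element of the output stabilizer; the exact indexing
    # scheme is an assumption about the library's internals.)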
    tw = TransformGFilter(self.inds)(w)

    # Fold the transformed filters
    tw_shape = (self.out_channels * self.output_stabilizer_size,
                self.in_channels * self.input_stabilizer_size,
                self.ksize, self.ksize)
    tw = F.Reshape(tw_shape)(tw)
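    # After folding, tw.shape == (out_channels * output_stabilizer_size,
    #                             in_channels * input_stabilizer_size, ksize, ksize),
    # i.e. an ordinary 4D filter bank that F.convolution_2d understands.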
    # If flat_channels is False, the input still carries a separate group
    # dimension; fold it into the channel axis so a plain 2D convolution
    # can be applied.
    if not self.flat_channels:
        batch_size = x.data.shape[0]
        in_ny, in_nx = x.data.shape[-2:]
        x = F.reshape(x, (batch_size, self.in_channels * self.input_stabilizer_size, in_ny, in_nx))

    # Perform the 2D convolution
    y = F.convolution_2d(x, tw, b=None, stride=self.stride, pad=self.pad, use_cudnn=self.use_cudnn)
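    # (Note: use_cudnn is a Chainer v1-era keyword argument; in Chainer v2+
    # cuDNN usage is controlled globally via chainer.using_config instead.)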
    # Unfold the output feature maps
    # We do this even if flat_channels is True, because we need to add the same bias to each G-feature map
    batch_size, _, ny_out, nx_out = y.data.shape
    y = F.reshape(y, (batch_size, self.out_channels, self.output_stabilizer_size, ny_out, nx_out))

    # Add a bias to each G-feature map
    if self.usebias:
        bb = F.Reshape((1, self.out_channels, 1, 1, 1))(self.b)
        y, b = F.broadcast(y, bb)
        y = y + b

    # Flatten feature channels if needed
    if self.flat_channels:
        n, nc, ng, nx, ny = y.data.shape
        y = F.reshape(y, (n, nc * ng, nx, ny))

    return y
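
For context, here is a minimal usage sketch of a layer built on this `__call__`. It assumes the `P4ConvZ2` class and import path from the GrouPy library this method comes from; the exact constructor signature is an assumption, so adjust it to your version of the package.

```python
import numpy as np
from chainer import Variable

# Hypothetical import path for GrouPy's Chainer implementation.
from groupy.gconv.chainer_gconv.p4_conv import P4ConvZ2

# A Z2 -> p4 layer: the input has no group axis (input_stabilizer_size == 1),
# while the output gains a group axis of size 4 (the four rotations in C4).
conv = P4ConvZ2(in_channels=3, out_channels=8, ksize=3, pad=1)

# With flat_channels=False (assumed default), the input carries an explicit
# stabilizer axis: (batch, in_channels, input_stabilizer_size, H, W).
x = Variable(np.random.randn(2, 3, 1, 32, 32).astype(np.float32))
y = conv(x)

# Expected output shape:
# (batch, out_channels, output_stabilizer_size, H_out, W_out) == (2, 8, 4, 32, 32)
print(y.data.shape)
```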