def forward(self, x):
    """Resample ``x`` at learned per-pixel offsets, then apply ``self.regular_filter``.

    ``self.offset_filter`` predicts two offset maps per input channel: one is
    added to the horizontal sampling coordinate and one to the vertical one.
    The offsets are added to a fixed identity grid spanning [-1, 1]. Each
    channel of ``x`` is then bilinearly resampled at the resulting locations
    with ``F.grid_sample``. The resampled tensor is passed through
    ``self.regular_filter``.

    The identity grid is cached on ``self.grid_w`` / ``self.grid_h``. It is
    rebuilt only when the spatial size, device or dtype of ``x`` changes.
    ``self.input_shape`` records the input shape it was last built for.

    Args:
        x: 4-D tensor of shape (b, c, h, w).

    Returns:
        The output of ``self.regular_filter`` applied to the resampled
        (b, c, h, w) tensor.
    """
    x_shape = x.size()  # (b, c, h, w)
    h, w = int(x_shape[2]), int(x_shape[3])

    offset = self.offset_filter(x)  # (b, 2*c, h, w)
    # Assumes regular_filter exposes in_channels == c, so the split yields
    # exactly two halves -- TODO confirm against __init__.
    offset_w, offset_h = torch.split(offset, self.regular_filter.in_channels, 1)  # each (b, c, h, w)
    offset_w = offset_w.contiguous().view(-1, h, w)  # (b*c, h, w)
    offset_h = offset_h.contiguous().view(-1, h, w)  # (b*c, h, w)

    # The base grid depends only on (h, w), so a batch-size change alone does
    # not force a rebuild. It must match x's device and dtype, because
    # grid_sample requires grid and input to share a dtype.
    cached = getattr(self, "grid_w", None)
    if (
        not self.input_shape
        or tuple(self.input_shape[2:]) != (h, w)
        or cached is None
        or cached.device != x.device
        or cached.dtype != x.dtype
    ):
        self.input_shape = x_shape
        xs = torch.linspace(-1, 1, w, device=x.device).to(x.dtype)
        ys = torch.linspace(-1, 1, h, device=x.device).to(x.dtype)
        # These are fixed coordinates, not weights. requires_grad=False keeps
        # optimizers from updating them; the grid_w / grid_h Parameter
        # attributes (and state_dict keys) stay as before.
        self.grid_w = nn.Parameter(
            xs.view(1, w).expand(h, w).contiguous(), requires_grad=False
        )  # (h, w), varies along width
        self.grid_h = nn.Parameter(
            ys.view(h, 1).expand(h, w).contiguous(), requires_grad=False
        )  # (h, w), varies along height

    offset_w = offset_w + self.grid_w  # (b*c, h, w)
    offset_h = offset_h + self.grid_h  # (b*c, h, w)

    x = x.contiguous().view(-1, h, w).unsqueeze(1)  # (b*c, 1, h, w)
    # grid_sample reads grid[..., 0] as the x (width) coordinate and
    # grid[..., 1] as y (height), so the width coordinates must come first.
    # align_corners=True makes -1/+1 the centres of the corner pixels, which
    # is what a linspace(-1, 1, n) grid assumes: zero offsets give an identity.
    # NOTE: checkpoints trained with the old (h, w) ordering learned offsets
    # against a transposed base grid.
    grid = torch.stack((offset_w, offset_h), 3)  # (b*c, h, w, 2)
    x = F.grid_sample(x, grid, align_corners=True)  # (b*c, 1, h, w)
    x = x.contiguous().view(-1, int(x_shape[1]), h, w)  # (b, c, h, w)
    return self.regular_filter(x)
评论列表
文章目录