def forward(self, input1):
    """Apply a batch of affine transforms to the stored sampling grid.

    Every grid point (a 3-vector) is multiplied by each sample's
    transform, i.e. ``out[n, h, w, :] = input1[n] @ grid[h, w, :]``.

    Args:
        input1: Tensor of shape ``(N, 2, 3)`` holding one transform per
            sample (the shape is fixed by the ``bmm`` against 3-component
            grid points and the final 2-channel view).

    Returns:
        Tensor of shape ``(N, self.height, self.width, 2)`` on the same
        device and with the same dtype as ``input1``.

    Side effects:
        Stores ``input1`` on ``self.input1`` and a freshly allocated
        per-sample copy of ``self.grid`` on ``self.batchgrid``.
        NOTE(review): these look like they are kept for a backward pass
        defined elsewhere -- confirm before removing.
    """
    self.input1 = input1
    batch_size = input1.size(0)

    # Allocate the batched grid directly on the input's device/dtype instead
    # of building it on the CPU and calling ``.cuda()`` (which always picks
    # the current default GPU). ``copy_`` broadcasts ``self.grid`` over the
    # new leading batch dimension and converts dtype/device as needed, so no
    # Python-level loop is required and ``self.grid`` is never aliased.
    self.batchgrid = torch.empty(
        (batch_size,) + tuple(self.grid.size()),
        dtype=input1.dtype,
        device=input1.device,
    )
    self.batchgrid.copy_(self.grid)

    # Use the explicit batch size rather than -1 so an empty batch does not
    # make the view shape ambiguous.
    flat_grid = self.batchgrid.view(batch_size, self.height * self.width, 3)

    # (N, 2, 3) -> (N, 3, 2); bmm accepts the non-contiguous transposed view.
    theta_t = input1.transpose(1, 2)

    # (N, H*W, 3) @ (N, 3, 2) -> (N, H*W, 2); the bmm result is contiguous,
    # so it can be viewed back into the spatial layout directly.
    flat_out = torch.bmm(flat_grid, theta_t)
    return flat_out.view(batch_size, self.height, self.width, 2)
评论列表
文章目录