def forward(self, input1):
    """Apply a batch of affine matrices to the stored sampling grid.

    ``self.grid`` is replicated once per batch element, flattened to
    ``(batch, height * width, 3)`` and batch-multiplied with ``input1``
    transposed on its last two dimensions.

    Args:
        input1: Tensor of shape ``(batch, 2, 3)``. The shape follows from
            the ``bmm`` against 3-component grid rows and the final view
            to a trailing dimension of 2.

    Returns:
        Tensor of shape ``(batch, self.height, self.width, 2)`` on the same
        device as ``input1``.

    Side effects:
        Stores the replicated grid on ``self.batchgrid`` with shape
        ``(batch, *self.grid.size())``. It is a fresh tensor that does not
        share memory with ``self.grid``.
    """
    batch = input1.size(0)

    # Follow the input's device (not just "the current CUDA device") so the
    # bmm below works when the input lives on a non-default GPU. Follow its
    # dtype as well when it is a floating type; otherwise keep the default
    # dtype, which is what the grid was previously allocated with.
    if input1.is_floating_point():
        dtype = input1.dtype
    else:
        dtype = torch.get_default_dtype()
    grid = self.grid.to(device=input1.device, dtype=dtype)

    # Replicate along a new leading batch dimension without a Python loop.
    # clone() materialises the broadcast so self.batchgrid owns its memory.
    self.batchgrid = grid.unsqueeze(0).expand(batch, *grid.size()).clone()

    # Use the explicit batch size instead of -1: view(-1, ...) is ambiguous
    # (and raises) when the batch is empty.
    flat_grid = self.batchgrid.view(batch, self.height * self.width, 3)
    output = torch.bmm(flat_grid, torch.transpose(input1, 1, 2))
    return output.view(batch, self.height, self.width, 2)
评论列表
文章目录