def forward(self, depth, trans0, trans1, rotate):
    """Project per-pixel depths to normalised (polar, azimuth) coordinates.

    Each cell of ``self.grid3d`` supplies a 3-D direction in channels 0-2.
    That direction is scaled by ``depth``, shifted by ``trans0`` along x and
    ``trans1`` along y, and the resulting point is converted to spherical
    angles. The azimuth is then offset by ``rotate`` (weighted by
    ``self.grid[..., 2]``) and wrapped back onto one turn.

    Args:
        depth: per-pixel depth; must broadcast with ``grid3d[..., 0:1]``
            tiled to ``(B, H, W, 1)`` -- presumably it *is* ``(B, H, W, 1)``;
            TODO confirm against the caller. ``B`` is ``depth.size(0)``.
        trans0: translation along x, one value per batch element.
        trans1: translation along y, one value per batch element.
        rotate: azimuth offset, one value per batch element, in the same
            units as the output azimuth (1.0 == pi radians when
            ``self.grid[..., 2]`` is 1).

    Returns:
        Tensor concatenated on dim 3 (``(B, H, W, 2)`` when ``self.grid`` has
        3 channels): channel 0 is the polar angle ``acos(z / r)`` mapped from
        ``[0, pi]`` to ``[-1, 1]``; channel 1 is
        ``azimuth / pi + grid[..., 2] * rotate`` wrapped into ``(-1, 1)``.

    Side effects:
        Sets ``self.batchgrid3d`` and ``self.batchgrid`` to batch-tiled copies
        of ``self.grid3d`` / ``self.grid`` with ``depth``'s tensor type.
    """
    eps = 1e-5
    batch = depth.size(0)

    # Tile the grids over the batch directly with depth's device/dtype. The
    # previous version filled CPU float tensors in a Python loop, which broke
    # CUDA inputs. .contiguous() still gives every sample its own copy.
    self.batchgrid3d = Variable(
        self.grid3d.expand(batch, *self.grid3d.size()).contiguous()
    ).type_as(depth)
    self.batchgrid = Variable(
        self.grid.expand(batch, *self.grid.size()).contiguous()
    ).type_as(depth)

    # Per-sample scalars broadcast over the spatial dims (no explicit repeat).
    x = self.batchgrid3d[:, :, :, 0:1] * depth + trans0.view(-1, 1, 1, 1)
    y = self.batchgrid3d[:, :, :, 1:2] * depth + trans1.view(-1, 1, 1, 1)
    z = self.batchgrid3d[:, :, :, 2:3] * depth

    # eps keeps z / r defined when the point sits at the origin.
    r = torch.sqrt(x ** 2 + y ** 2 + z ** 2) + eps
    theta = torch.acos(z / r) / (np.pi / 2) - 1

    # Azimuth in units of pi. atan2 resolves the quadrant itself; the old
    # atan(y / (x + eps)) plus a pi-correction keyed on x < 0 was off by pi
    # for -eps < x < 0 and divided by zero at x == -eps. The eps shift on x is
    # kept so atan2 and its gradient stay finite at x == y == 0.
    phi = torch.atan2(y, x + eps) / np.pi

    angles = torch.cat([theta, phi], 3)
    # atan(tan(pi/2 * v)) / (pi/2) wraps v into (-1, 1) with period 2,
    # i.e. folds the rotated azimuth back onto a single turn.
    shifted = angles[:, :, :, 1:2] + self.batchgrid[:, :, :, 2:] * rotate.view(-1, 1, 1, 1)
    wrapped = torch.atan(torch.tan(np.pi / 2.0 * shifted)) / (np.pi / 2)
    return torch.cat([angles[:, :, :, 0:1], wrapped], 3)
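# Web-page residue from the copied source ("Comment list" / "Table of contents"
# navigation labels) -- not part of the code.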