def forward(self, x):
    """Run the feature extractor on a batch of inputs.

    The transform predicted by ``self.stn`` is applied to ``x`` with a batched
    matrix multiply. The result then goes through conv1/bn1 (whose ReLU output
    is kept as the per-point feature), conv2/bn2 with ReLU, and conv3/bn3
    without ReLU. Finally ``self.mp1`` pools it into a global feature.

    Args:
        x: 3-D tensor (``torch.bmm`` requires rank 3). Assumed to be
            ``(batch, channels, points)`` with ``channels`` matching the
            size of the transform returned by ``self.stn``. TODO confirm
            against callers.

    Returns:
        If ``self.global_feat`` is truthy: ``(global_vec, trans)``, where
        ``global_vec`` has shape ``(batch, F)`` (F is 1024 for the usual
        conv3 configuration).

        Otherwise: ``(features, trans)``. Here ``features`` concatenates the
        global vector, tiled across every point, with the first-layer
        per-point features along dim 1.
    """
    batch_size = x.size(0)
    trans = self.stn(x)
    # Move the channel axis last, multiply by the per-sample transform, and
    # move it back so the Conv1d layers see the original layout.
    x = torch.bmm(x.transpose(2, 1), trans).transpose(2, 1)

    x = F.relu(self.bn1(self.conv1(x)))
    pointfeat = x  # first-layer features, reused in the non-global branch
    x = F.relu(self.bn2(self.conv2(x)))
    x = self.bn3(self.conv3(x))
    # NOTE(review): self.mp1 is assumed to reduce the point axis to length 1
    # (e.g. MaxPool1d(num_points) or a global max pool). TODO confirm.
    x = self.mp1(x)

    # Flatten per sample. Anchoring on batch_size (rather than view(-1, 1024))
    # keeps the batch dimension intact and does not hard-code the width.
    global_vec = x.view(batch_size, -1)
    if self.global_feat:
        return global_vec, trans

    # Tile to the actual point count of pointfeat, not self.num_points, so the
    # concatenation below lines up even if the input size differs from config.
    num_pts = pointfeat.size(2)
    tiled = global_vec.unsqueeze(2).repeat(1, 1, num_pts)
    return torch.cat([tiled, pointfeat], 1), trans
评论列表
文章目录