def forward(self, x):
    """Predict a per-sample square transform matrix, offset by the identity.

    The input passes through two conv/batch-norm/ReLU stages and ``mp1``. It
    is then flattened and fed to a three-layer fully connected head
    (``fc1``/``fc2``/``fc3``). The head output is reshaped to one ``k x k``
    matrix per sample, and the identity matrix is added. An all-zero head
    output therefore yields the identity transform.

    Args:
        x: Input tensor. Presumably ``(batch, channels, num_points)`` for 1-D
            convolutions; TODO confirm against ``conv1``. The original shape
            comments assume 2048 points.

    Returns:
        Tensor of shape ``(batch, k, k)``, where ``k * k`` is the width of
        ``fc3``'s output (128 * 128 in the original configuration). It has
        the same dtype and device as ``fc3``'s output.

    Raises:
        RuntimeError: From ``view`` when ``mp1`` does not reduce the input to
            ``(batch, 1024, 1)``, or when ``fc3``'s width is not a perfect
            square.
    """
    batchsize = x.size(0)
    # Expected shapes, per the original notes:
    x = F.relu(self.bn1(self.conv1(x)))  # bz x 256 x 2048
    x = F.relu(self.bn2(self.conv2(x)))  # bz x 1024 x 2048
    x = self.mp1(x)  # bz x 1024 x 1
    # Use an explicit batch dim so that a pooled length other than 1 raises
    # here instead of being silently folded into the batch dimension.
    x = x.view(batchsize, 1024)
    x = F.relu(self.bn3(self.fc1(x)))  # bz x 512
    x = F.relu(self.bn4(self.fc2(x)))  # bz x 256
    x = self.fc3(x)  # bz x (k*k)

    # Infer k from the head width rather than hard-coding 128. For the
    # original 128*128 head this gives k == 128. A non-square width makes
    # the view below raise.
    k = int(round(x.size(1) ** 0.5))
    # Build the identity directly on x's device and in x's dtype. The old
    # numpy -> CPU float32 -> .cuda() path failed on non-CUDA devices and
    # changed the result dtype for half/double outputs.
    iden = torch.eye(k, dtype=x.dtype, device=x.device)
    # Broadcasting adds the same identity to every sample in the batch.
    return x.view(batchsize, k, k) + iden
评论列表
文章目录