def forward(self, input):
    """Run the layer stack on ``input``; return normalised and raw features.

    If ``input`` is a CUDA tensor and ``self.gpuDevice`` is non-zero, the
    tensor is first moved to that device. Inputs whose last dimension is 128
    are passed through ``self.resize1`` and then ``self.resize2`` before the
    main stack.

    Args:
        input: Input tensor. The name shadows the builtin but is kept so
            keyword callers (``forward(input=...)``) keep working.
            NOTE(review): expected shape is not visible here; confirm
            against callers.

    Returns:
        tuple ``(x, x_736)``:
            x: output of ``self.layer25`` with every row divided by
                ``sqrt(sum(row ** 2) + 1e-6)`` (approximately unit L2 norm).
            x_736: the output of ``self.layer22`` reshaped to
                ``(-1, 736)``, i.e. exactly the tensor fed into
                ``self.layer25``.
    """
    x = input
    # `.is_cuda` is a direct tensor attribute; the legacy `.data` detour is
    # unnecessary.
    if x.is_cuda and self.gpuDevice != 0:
        x = x.cuda(self.gpuDevice)

    if x.size(-1) == 128:
        x = self.resize2(self.resize1(x))

    # Apply the main stack in order.
    # NOTE(review): layer20, layer23 and layer24 are skipped; confirm this is
    # intentional.
    for layer in (
        self.layer1, self.layer2, self.layer3, self.layer4,
        self.layer5, self.layer6, self.layer7, self.layer8,
        self.layer9, self.layer10, self.layer11, self.layer12,
        self.layer13, self.layer14, self.layer15, self.layer16,
        self.layer17, self.layer18, self.layer19,
        self.layer21, self.layer22,
    ):
        x = layer(x)

    # `reshape` returns the same view as `view` for contiguous tensors, but
    # also accepts non-contiguous activations (where `view` would raise).
    x = x.reshape(-1, 736)
    x_736 = x
    x = self.layer25(x)

    # Row-wise L2 normalisation. The 1e-6 inside the sqrt keeps the
    # denominator strictly positive for all-zero rows.
    x_norm = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True) + 1e-6)
    x = x / x_norm
    return (x, x_736)
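# Navigation text left over from the web page this code was copied from
# (not part of the program): "评论列表" ("Comment list"), "文章目录" ("Table of contents").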