def forward(self, x):
    """Run the backbone up to its pooling layer, then the optional head.

    Args:
        x: Input batch fed to the first child module of ``self.base``.
            Presumably an image tensor of shape (N, C, H, W); the spatial
            pooling below uses ``x.size()[2:]`` as the kernel size.
            TODO confirm against callers.

    Returns:
        If ``self.cut_at_pooling`` is truthy: the output of the children of
        ``self.base`` registered before the one named ``'avgpool'`` (all
        children, if there is no such name).

        Otherwise: those features average-pooled over their full spatial
        extent and flattened to ``(batch, -1)``. They then pass through:
        - ``feat`` then ``feat_bn``, when ``self.has_embedding``;
        - ``F.normalize`` when ``self.norm``; otherwise ``F.relu``, but
          only when an embedding was applied;
        - ``self.drop``, when ``self.dropout > 0``;
        - ``self.classifier``, when ``self.num_classes > 0``.
    """
    # Walk the backbone's registered children in order and stop just before
    # its own 'avgpool'. Pooling is done by hand below, so the backbone's
    # own pooling layer and anything after it (e.g. its fc) are skipped.
    for child_name, child in self.base._modules.items():
        if child_name == 'avgpool':
            break
        x = child(x)

    if self.cut_at_pooling:
        return x

    # Average over the whole spatial extent (kernel = remaining dims), then
    # flatten everything except the batch dimension.
    spatial_extent = x.size()[2:]
    pooled = F.avg_pool2d(x, spatial_extent)
    out = pooled.view(pooled.size(0), -1)

    if self.has_embedding:
        out = self.feat_bn(self.feat(out))

    # Normalization takes precedence. ReLU is applied only to embedding
    # outputs, never to the raw pooled backbone features.
    if self.norm:
        out = F.normalize(out)
    elif self.has_embedding:
        out = F.relu(out)

    if self.dropout > 0:
        out = self.drop(out)

    if self.num_classes > 0:
        out = self.classifier(out)
    return out
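# "评论列表" ("Comment list"): leftover web-page label, not part of the code.
# "文章目录" ("Table of contents"): leftover web-page label, not part of the code.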