def forward(self, x):
    """Run the Inception trunk and the optional embedding/classifier head.

    Args:
        x: Input tensor fed to ``self.conv1``; presumably an image batch of
            shape (N, C, H, W) -- TODO confirm against the caller.

    Returns:
        If ``self.cut_at_pooling`` is set, the raw output of the last
        Inception stage (``inception6b``), before any pooling.
        Otherwise, the pooled features flattened to 2-D (dim 0 kept).
        They are then passed through whichever head stages are enabled:
        embedding (``feat`` + ``feat_bn``), then L2 normalisation or
        ReLU, then dropout, then the classifier.
    """
    # Convolutional stem followed by the Inception stages, in order.
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.pool3(x)
    x = self.inception4a(x)
    x = self.inception4b(x)
    x = self.inception5a(x)
    x = self.inception5b(x)
    x = self.inception6a(x)
    x = self.inception6b(x)

    # Caller asked for the trunk's feature map, before pooling and head.
    if self.cut_at_pooling:
        return x

    x = self.avgpool(x)
    # Keep dim 0 and merge every remaining dim. flatten(1) replaces
    # x.view(x.size(0), -1), which raises on an empty batch because the
    # -1 cannot be inferred from zero elements.
    x = x.flatten(1)

    if self.has_embedding:
        x = self.feat(x)
        x = self.feat_bn(x)

    # Normalisation and ReLU are mutually exclusive: when self.norm is
    # set, the ReLU is skipped even if an embedding layer is present.
    if self.norm:
        # L2-normalise each row (F.normalize defaults: p=2, dim=1).
        x = F.normalize(x)
    elif self.has_embedding:
        x = F.relu(x)

    if self.dropout > 0:
        x = self.drop(x)

    if self.num_classes > 0:
        x = self.classifier(x)
    return x
评论列表
文章目录