def modify_densenets(model):
    """Adapt a DenseNet-style model to the ``last_linear`` / ``logits`` API.

    Renames ``model.classifier`` to ``model.last_linear`` and installs
    ``logits`` and ``forward`` methods, so the network computes
    ``features -> logits (ReLU, global average pool, last_linear)``.

    Args:
        model: An ``nn.Module`` exposing ``features`` and ``classifier``
            submodules (presumably a torchvision ``DenseNet`` — verify at
            call sites).

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        AttributeError: if ``model`` has no ``classifier`` attribute, e.g.
            when this function has already been applied to the same model.

    NOTE: ``logits`` and ``forward`` are installed on ``model.__class__``,
    not on the instance. Every instance of that class therefore gets the new
    methods, including instances never passed through this function. Those
    instances still lack ``last_linear``, so their forward pass will fail.
    """
    # Modify attributes: expose the classification head under the new name.
    model.last_linear = model.classifier
    del model.classifier

    def logits(self, features):
        """Pool ``features`` to one vector per sample and apply ``last_linear``."""
        # The feature extractor presumably ends with a normalisation layer and
        # no activation (as in torchvision's DenseNet), hence the ReLU — verify.
        x = F.relu(features, inplace=True)
        # Global average pooling to 1x1. For the 7x7 maps produced by 224x224
        # inputs this equals avg_pool2d(kernel_size=7, stride=1). Unlike a
        # fixed kernel, it also works for other input resolutions.
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        x = self.last_linear(x)
        return x

    def forward(self, input):
        """Run ``features`` followed by ``logits``."""
        # The parameter name ``input`` is kept for keyword-argument compatibility.
        x = self.features(input)
        x = self.logits(x)
        return x

    # Modify methods (class-level patch; see the NOTE in the docstring).
    setattr(model.__class__, 'logits', logits)
    setattr(model.__class__, 'forward', forward)
    return model
# Source page: torchvision_models.py (file source code)
# Language: python
# Reads: 24
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Article contents