def generate_model(num_classes=17, backbone=None):
    """Build a DenseNet-121 based multi-label classifier.

    The backbone's ``features`` extractor is reused unchanged. Its classifier
    is replaced by a freshly initialised ``nn.Linear`` whose outputs pass
    through a sigmoid, so each output is an independent per-class score.

    Args:
        num_classes: Number of output labels. Defaults to 17, the value that
            was previously hard-coded.
        backbone: Optional DenseNet-style model exposing ``features`` (with
            children named ``denseblock1`` .. ``denseblock4``) and
            ``classifier`` (an ``nn.Linear``). When ``None``, an
            ImageNet-pretrained ``torchvision.models.densenet121`` is built,
            which may download weights on first use.

    Returns:
        A ``DenseModel`` module mapping an image batch to a
        ``(batch, num_classes)`` tensor of sigmoid scores.
    """

    class DenseModel(nn.Module):
        def __init__(self, pretrained_model):
            super(DenseModel, self).__init__()
            self.classifier = nn.Linear(
                pretrained_model.classifier.in_features, num_classes)
            # Only the new head is initialised. Earlier code looped over
            # self.modules() here, but the classifier is the only registered
            # child at this point, so its Conv2d/BatchNorm branches never ran.
            # The backbone is attached below, after this init, so its
            # pretrained weights are left untouched.
            self.classifier.bias.data.zero_()
            self.features = pretrained_model.features
            # Aliases to dense blocks already inside self.features; forward()
            # does not use them. Kept so state_dict keys match checkpoints
            # saved with this model layout.
            self.dense1 = pretrained_model.features._modules['denseblock1']
            self.dense2 = pretrained_model.features._modules['denseblock2']
            self.dense3 = pretrained_model.features._modules['denseblock3']
            self.dense4 = pretrained_model.features._modules['denseblock4']

        def forward(self, x):
            features = self.features(x)
            out = F.relu(features, inplace=True)
            # Global average pooling over whatever spatial size the backbone
            # produced. The former avg_pool2d(kernel_size=8) only collapsed
            # exactly 8x8 maps (e.g. 256x256 inputs for DenseNet-121) and
            # failed or pooled just a corner for any other size.
            out = F.adaptive_avg_pool2d(out, 1).flatten(1)
            return self.classifier(out).sigmoid()

    if backbone is None:
        backbone = torchvision.models.densenet121(pretrained=True)
    return DenseModel(backbone)
nn_finetune_densenet_121.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录