torchvision_models.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pretrained-models.pytorch 作者: Cadene 项目源码 文件源码
def modify_densenets(model):
    # Modify attributs
    model.last_linear = model.classifier
    del model.classifier

    def logits(self, features):
        x = F.relu(features, inplace=True)
        x = F.avg_pool2d(x, kernel_size=7, stride=1)
        x = x.view(x.size(0), -1)
        x = self.last_linear(x)
        return x

    def forward(self, input):
        x = self.features(input)
        x = self.logits(x)
        return x

    # Modify methods
    setattr(model.__class__, 'logits', logits)
    setattr(model.__class__, 'forward', forward)
    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号