deepten.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:PyTorch-Encoding 作者: zhanghang1989 项目源码 文件源码
def __init__(self, args):
        """Set up the backbone network and the encoding-based classifier head.

        Args:
            args: Options object. This constructor reads ``args.nclass``
                (output size of the final linear layer) and
                ``args.backbone`` (must be 'resnet50', 'resnet101' or
                'resnet152').

        Raises:
            RuntimeError: If ``args.backbone`` is not one of the three
                supported names.
        """
        num_classes = args.nclass
        super(Net, self).__init__()
        self.backbone = args.backbone

        # Each supported name is also the name of the builder function in
        # ``resnet``, so reject unknown names first and then look it up.
        supported_backbones = ('resnet50', 'resnet101', 'resnet152')
        if self.backbone not in supported_backbones:
            raise RuntimeError('unknown backbone: {}'.format(self.backbone))
        build_backbone = getattr(resnet, self.backbone)
        # NOTE(review): pretrained=True presumably loads pretrained weights;
        # confirm against the project's ``resnet`` module.
        self.pretrained = build_backbone(pretrained=True)

        in_channels = 2048   # input channels expected by the 1x1 conv
        reduced_dim = 128    # channel width after the 1x1 conv (Encoding's D)
        codebook_size = 32   # Encoding's K
        encoded_dim = reduced_dim * codebook_size

        head_layers = [
            nn.Conv2d(in_channels, reduced_dim, 1),
            nn.BatchNorm2d(reduced_dim),
            nn.ReLU(inplace=True),
            encoding.nn.Encoding(D=reduced_dim, K=codebook_size),
            # Flatten to (-1, D*K) so the result can feed the linear layer.
            encoding.nn.View(-1, encoded_dim),
            encoding.nn.Normalize(),
            nn.Linear(encoded_dim, num_classes),
        ]
        self.head = nn.Sequential(*head_layers)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号