def __init__(self, args):
    """Set up the pretrained backbone and the classification head.

    Args:
        args: Configuration object. Two attributes are read from it:
            ``nclass`` (output size of the final linear layer) and
            ``backbone`` (one of ``'resnet50'``, ``'resnet101'``,
            ``'resnet152'``).

    Raises:
        RuntimeError: If ``args.backbone`` is not one of the three
            supported names.
    """
    num_classes = args.nclass
    super(Net, self).__init__()

    backbone_name = args.backbone
    self.backbone = backbone_name

    # copying modules from pretrained models
    supported_backbones = ('resnet50', 'resnet101', 'resnet152')
    if backbone_name not in supported_backbones:
        raise RuntimeError('unknown backbone: {}'.format(backbone_name))
    # Each supported name matches a constructor of the same name in `resnet`.
    build_backbone = getattr(resnet, backbone_name)
    self.pretrained = build_backbone(pretrained=True)

    n_codes = 32
    reduced_channels = 128
    encoded_size = reduced_channels * n_codes

    # NOTE(review): 2048 input channels presumably matches the backbone's
    # last feature map -- confirm against the forward pass.
    head_layers = [
        nn.Conv2d(2048, reduced_channels, 1),
        nn.BatchNorm2d(reduced_channels),
        nn.ReLU(inplace=True),
        encoding.nn.Encoding(D=reduced_channels, K=n_codes),
        encoding.nn.View(-1, encoded_size),
        encoding.nn.Normalize(),
        nn.Linear(encoded_size, num_classes),
    ]
    self.head = nn.Sequential(*head_layers)
# [web page residue] Comment list
# [web page residue] Article table of contents