def resnet50_wildcat(num_classes, pretrained=True, kmax=1, kmin=None, alpha=1, num_maps=1):
    """Build a WILDCAT-style model on top of a torchvision ResNet-50 backbone.

    The pooling head is an ``nn.Sequential`` with two stages:

    * ``'class_wise'``: a ``ClassWisePool(num_maps)``;
    * ``'spatial'``: a ``WildcatPool2d(kmax, kmin, alpha)``.

    Args:
        num_classes: Number of output classes.
        pretrained: Whether to request pretrained backbone weights. It is
            forwarded to ``torchvision.models.resnet50`` as the ``pretrained``
            keyword.
        kmax: Forwarded unchanged to ``WildcatPool2d``.
        kmin: Forwarded unchanged to ``WildcatPool2d``. How ``None`` is
            interpreted is decided by ``WildcatPool2d``, which is not visible
            here.
        alpha: Forwarded unchanged to ``WildcatPool2d``.
        num_maps: Passed to ``ClassWisePool``. Also multiplied with
            ``num_classes`` to form the second argument of ``ResNetWSL``.

    Returns:
        ``ResNetWSL(model, num_classes * num_maps, pooling=pooling)`` built
        from the backbone and pooling head above.
    """
    # Pass ``pretrained`` by keyword, not positionally. Since torchvision 0.13,
    # ``resnet50`` is keyword-only (``weights=...``): a positional bool gets
    # bound to ``weights`` and raises TypeError. The ``pretrained=`` keyword is
    # the native parameter in older releases. Newer releases still accept it
    # through their legacy-interface shim, with a deprecation warning.
    model = models.resnet50(pretrained=pretrained)

    # nn.Sequential applies its modules in insertion order: class-wise pooling
    # first, then spatial pooling.
    pooling = nn.Sequential()
    pooling.add_module('class_wise', ClassWisePool(num_maps))
    pooling.add_module('spatial', WildcatPool2d(kmax, kmin, alpha))

    # NOTE(review): presumably ResNetWSL's classifier emits
    # num_classes * num_maps channels, which ClassWisePool then reduces to
    # num_classes. Confirm against ResNetWSL's definition.
    return ResNetWSL(model, num_classes * num_maps, pooling=pooling)
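# "评论列表" (comment list): scraped blog-page navigation text, not code.
# "文章目录" (article table of contents): scraped blog-page navigation text, not code.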