extract_features.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:clevr-iep 作者: facebookresearch 项目源码 文件源码
def build_model(args):
  if not hasattr(torchvision.models, args.model):
    raise ValueError('Invalid model "%s"' % args.model)
  if not 'resnet' in args.model:
    raise ValueError('Feature extraction only supports ResNets')
  cnn = getattr(torchvision.models, args.model)(pretrained=True)
  layers = [
    cnn.conv1,
    cnn.bn1,
    cnn.relu,
    cnn.maxpool,
  ]
  for i in range(args.model_stage):
    name = 'layer%d' % (i + 1)
    layers.append(getattr(cnn, name))
  model = torch.nn.Sequential(*layers)
  model.cuda()
  model.eval()
  return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号