def build_model(num_classes=101, pretrained=True):
    """Construct the network selected by the module-level ``args`` and move it to CUDA.

    The constructor is looked up in the ``models`` module under the name
    ``args.modality + "_" + args.arch``. It is called with ``pretrained``
    and ``num_classes``.

    Args:
        num_classes: Number of output classes passed to the model constructor.
            Defaults to 101 (presumably the UCF101 class count -- TODO confirm).
        pretrained: Forwarded to the model constructor. Defaults to ``True``.

    Returns:
        The constructed model on the GPU. For architectures whose name starts
        with ``alexnet`` or ``vgg``, only ``model.features`` is wrapped in
        ``torch.nn.DataParallel``. Every other architecture is wrapped as a
        whole, and the returned object is the ``DataParallel`` wrapper.

    Raises:
        KeyError: If ``models`` defines no constructor with the composed name.
    """
    model_name = args.modality + "_" + args.arch
    if model_name not in models.__dict__:
        # Same exception type as the original bare lookup, but say what was
        # requested so a typo in modality/arch is easy to spot.
        raise KeyError(
            "unknown model {!r}: no constructor with this name in `models` "
            "(composed from args.modality and args.arch)".format(model_name))
    model = models.__dict__[model_name](pretrained=pretrained,
                                        num_classes=num_classes)

    # AlexNet/VGG: parallelize only the convolutional feature extractor and keep
    # the classifier on a single device. This mirrors the PyTorch ImageNet example;
    # presumably the large fully-connected layers don't benefit from replication.
    if args.arch.startswith(("alexnet", "vgg")):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()  # nn.Module.cuda() moves parameters in place
    else:
        model = torch.nn.DataParallel(model).cuda()
    return model
评论列表
文章目录