def _prepare_base_model(self, base_model):
    """Build the backbone network and set its input preprocessing statistics.

    Assigns ``self.base_model`` (tagged with ``last_layer_name``, the name of
    its final classification layer), ``self.input_size``, ``self.input_mean``
    and ``self.input_std``. Mean/std are adapted to ``self.modality``
    (``'Flow'`` or ``'RGBDiff'``); ``self.new_length`` is read for
    ``'RGBDiff'``.

    Args:
        base_model: Backbone name. ``'BNInception'`` / ``'InceptionV3'``
            (exact match) and any name containing ``'inception'`` are built
            from the local ``model_zoo`` module; otherwise names containing
            ``'resnet'`` or ``'vgg'`` are built from ``torchvision.models``
            with pretrained weights.

    Raises:
        ValueError: If ``base_model`` matches none of the rules above.
    """
    # model_zoo backbones that share the same preprocessing:
    # name -> (last layer name, input size).
    zoo_specs = {
        'BNInception': ('fc', 224),
        'InceptionV3': ('top_cls_fc', 299),
    }

    if base_model in zoo_specs:
        import model_zoo
        last_layer_name, input_size = zoo_specs[base_model]
        self.base_model = getattr(model_zoo, base_model)()
        self.base_model.last_layer_name = last_layer_name
        self.input_size = input_size
        # Per-channel means with unit std (presumably 0-255 pixel values in
        # Caffe-style channel order -- confirm against model_zoo's loaders).
        self.input_mean = [104, 117, 128]
        self.input_std = [1]
        if self.modality == 'Flow':
            self.input_mean = [128]
        elif self.modality == 'RGBDiff':
            # Repeat the 3-channel mean (1 + new_length) times.
            self.input_mean = self.input_mean * (1 + self.new_length)
    elif 'inception' in base_model:
        # Must be checked before the 'resnet' substring test below: names
        # such as 'inceptionresnetv2' contain 'resnet' and would otherwise be
        # routed to torchvision.models, which has no such model.
        import model_zoo
        self.base_model = getattr(model_zoo, base_model)()
        self.base_model.last_layer_name = 'classif'
        self.input_size = 299
        self.input_mean = [0.5]
        self.input_std = [0.5]
    elif 'resnet' in base_model or 'vgg' in base_model:
        self.base_model = getattr(torchvision.models, base_model)(pretrained=True)
        # NOTE(review): torchvision VGG models expose their head as
        # `classifier` (an nn.Sequential), not `fc`; confirm that callers
        # resolving `last_layer_name` handle VGG.
        self.base_model.last_layer_name = 'fc'
        self.input_size = 224
        rgb_mean = [0.485, 0.456, 0.406]
        rgb_std = [0.229, 0.224, 0.225]
        self.input_mean = rgb_mean
        self.input_std = rgb_std
        if self.modality == 'Flow':
            self.input_mean = [0.5]
            self.input_std = [np.mean(rgb_std)]
        elif self.modality == 'RGBDiff':
            # The 3 * new_length extra channels get mean 0 and twice the
            # average RGB std.
            extra = 3 * self.new_length
            self.input_mean = rgb_mean + [0] * extra
            self.input_std = rgb_std + [np.mean(rgb_std) * 2] * extra
    else:
        raise ValueError('Unknown base model: {}'.format(base_model))
# Stray web-page residue (not code): "Comment list" / "Table of contents".