    def _nn_forward_hook(self, module, input, output, name=''):
        # Forward hook: clone the layer's output (or each tensor in a list of
        # outputs) into self.blobs under the given layer name, so the
        # activations remain available after the forward pass.
        if isinstance(output, list):
            self.blobs[name] = [o.data.clone() for o in output]
        else:
            self.blobs[name] = output.data.clone()
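The `name` keyword suggests the hook is bound per layer when it is registered. A minimal registration sketch follows; this is an assumption about how the surrounding class wires things up, and `self.model` and `_register_hooks` are hypothetical names not taken from the original:

    def _register_hooks(self):
        # Hypothetical helper: attach _nn_forward_hook to every named
        # submodule, binding the layer name with functools.partial so
        # register_forward_hook can still call it as hook(module, input, output).
        import functools
        for layer_name, layer in self.model.named_modules():
            layer.register_forward_hook(
                functools.partial(self._nn_forward_hook, name=layer_name))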
# @staticmethod
# def _load_model_config(model_def):
#         if isinstance(model_def, torch.nn.Module):
#             # Already an instantiated model; nothing to load.
#             return model_def
#         elif '.' not in os.path.basename(model_def):
#             import torchvision.models as models
#             if model_def not in models.__dict__:
#                 raise KeyError('Model {} does not exist in the pytorch model zoo.'.format(model_def))
#             print('Loading model {} from the pytorch model zoo'.format(model_def))
# return models.__dict__[model_def](pretrained=True)
# else:
# print('Loading model from {}'.format(model_def))
# if model_def.endswith('.t7'):
# return load_legacy_model(model_def)
# else:
# return torch.load(model_def)
#
#
# if type(model_cfg) == str:
# if not os.path.exists(model_cfg):
# try:
# class_ = getattr(applications, model_cfg)
# return class_(weights=model_weights)
# except AttributeError:
# available_mdls = [attr for attr in dir(applications) if callable(getattr(applications, attr))]
# raise ValueError('Could not load pretrained model with key {}. '
# 'Available models: {}'.format(model_cfg, ', '.join(available_mdls)))
#
#         with open(model_cfg, 'r') as fileh:
#             cfg_text = fileh.read()
#
#         # Try JSON first, then fall back to YAML.
#         try:
#             return model_from_json(cfg_text)
#         except ValueError:
#             pass
#
#         try:
#             return model_from_yaml(cfg_text)
#         except ValueError:
#             pass
#
# raise ValueError('Could not load model from configuration file {}. '
# 'Make sure the path is correct and the file format is yaml or json.'.format(model_cfg))
# elif type(model_cfg) == dict:
# return Model.from_config(model_cfg)
# elif type(model_cfg) == list:
# return Sequential.from_config(model_cfg)
#
# raise ValueError('Could not load model from configuration object of type {}.'.format(type(model_cfg)))
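The commented-out loader above fuses two code paths: a PyTorch one keyed on `model_def` (torchvision model zoo, legacy `.t7` files, or `torch.load`) and a Keras one keyed on `model_cfg` (`applications`, JSON/YAML configs). Below is a minimal sketch of just the PyTorch branch; `_load_torch_model` is a hypothetical name, the `.t7` legacy path is omitted, and `os`/`torch` are assumed to be imported at module level:

    @staticmethod
    def _load_torch_model(model_def):
        # Accepts an nn.Module instance, a torchvision model-zoo name, or a
        # path to a serialized model file.
        if isinstance(model_def, torch.nn.Module):
            return model_def
        if '.' not in os.path.basename(model_def):
            import torchvision.models as models
            if model_def not in models.__dict__:
                raise KeyError('Model {} does not exist in the pytorch model zoo.'.format(model_def))
            print('Loading model {} from the pytorch model zoo'.format(model_def))
            return models.__dict__[model_def](pretrained=True)
        print('Loading model from {}'.format(model_def))
        return torch.load(model_def)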