torch.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目: emu 作者: mlosch 项目源码 文件源码
def _nn_forward_hook(self, module, input, output, name=''):
        """Record a copy of a module's forward output in ``self.blobs[name]``.

        The ``(module, input, output)`` parameters follow the signature of a
        PyTorch forward hook. ``name`` is presumably bound by the caller (e.g.
        via ``functools.partial``) when the hook is registered; confirm at the
        registration site.

        Args:
            module: The module that produced ``output`` (unused here).
            input: The module's forward inputs (unused here).
            output: The forward result. It is one of:
                - a single tensor-like object exposing ``.data.clone()``;
                - ``None``;
                - an arbitrarily nested list/tuple of those.
            name: Key under which the snapshot is stored in ``self.blobs``.

        Side effects:
            Sets ``self.blobs[name]``:
                - A single output is stored as its cloned ``.data``.
                - A list or tuple output is stored as a list of cloned
                  ``.data`` values; nested sequences become nested lists.
                - ``None`` entries are kept as ``None``.
        """
        def _snapshot(value):
            # Recurse into multi-output results. Tuples are handled as well
            # as lists: many modules (e.g. recurrent layers) return tuples,
            # which previously failed with AttributeError on ``tuple.data``.
            if isinstance(value, (list, tuple)):
                return [_snapshot(item) for item in value]
            # Some modules emit ``None`` for absent optional outputs; there is
            # nothing to copy.
            if value is None:
                return None
            # ``.data`` drops autograd tracking; ``clone()`` makes an
            # independent copy, so later in-place ops on the live output do
            # not alter the recorded blob.
            return value.data.clone()

        self.blobs[name] = _snapshot(output)

    # @staticmethod
    # def _load_model_config(model_def):
    #     if isinstance(model_def, torch.nn.Module):
    #
    #     elif '.' not in os.path.basename(model_def):
    #         import torchvision.models as models
    #         if model_def not in models.__dict__:
    #             raise KeyError('Model {} does not exist in pytorchs model zoo.'.format(model_def))
    #         print('Loading model {} from pytorch model zoo'.format(model_def))
    #         return models.__dict__[model_def](pretrained=True)
    #     else:
    #         print('Loading model from {}'.format(model_def))
    #         if model_def.endswith('.t7'):
    #             return load_legacy_model(model_def)
    #         else:
    #             return torch.load(model_def)
    #
    #
    #     if type(model_cfg) == str:
    #         if not os.path.exists(model_cfg):
    #             try:
    #                 class_ = getattr(applications, model_cfg)
    #                 return class_(weights=model_weights)
    #             except AttributeError:
    #                 available_mdls = [attr for attr in dir(applications) if callable(getattr(applications, attr))]
    #                 raise ValueError('Could not load pretrained model with key {}. '
    #                                  'Available models: {}'.format(model_cfg, ', '.join(available_mdls)))
    #
    #         with open(model_cfg, 'r') as fileh:
    #             try:
    #                 return model_from_json(fileh)
    #             except ValueError:
    #                 pass
    #
    #             try:
    #                 return model_from_yaml(fileh)
    #             except ValueError:
    #                 pass
    #
    #         raise ValueError('Could not load model from configuration file {}. '
    #                          'Make sure the path is correct and the file format is yaml or json.'.format(model_cfg))
    #     elif type(model_cfg) == dict:
    #         return Model.from_config(model_cfg)
    #     elif type(model_cfg) == list:
    #         return Sequential.from_config(model_cfg)
    #
    #     raise ValueError('Could not load model from configuration object of type {}.'.format(type(model_cfg)))
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号