image_featurizers.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:ParlAI 作者: facebookresearch 项目源码 文件源码
def init_cnn(self, opt):
    """Lazily initialize the image preprocessor (a pretrained CNN).

    Kept separate from construction so that the heavy dependencies are only
    imported when image preprocessing is actually needed.

    Args:
        opt: option mapping. Must contain ``image_mode``, ``image_size`` and
            ``image_cropsize``. ``no_cuda`` (default ``False``) and ``gpu``
            (default ``-1``) are read when present.

    Raises:
        RuntimeError: if a required image option is missing from ``opt``.
        ModuleNotFoundError: if torch, torchvision or h5py is not installed.

    Sets ``self.use_cuda``, ``self.torch``, ``self.h5py``,
    ``self.image_mode``, ``self.image_size``, ``self.crop_size``,
    ``self.netCNN``, ``self.transform`` and ``self.xs``.
    """
    # Validate the options first so a misconfigured ``opt`` fails fast with a
    # helpful message, before any heavy import. 'image_cropsize' is read
    # below as well, so it is checked here instead of surfacing as a bare
    # KeyError.
    required_keys = ('image_mode', 'image_size', 'image_cropsize')
    if any(key not in opt for key in required_keys):
        raise RuntimeError(
            'Need to add image arguments to opt. See '
            'parlai.core.params.ParlaiParser.add_image_args')

    try:
        import torch
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            'Need to install Pytorch: go to pytorch.org') from err
    self.use_cuda = (not opt.get('no_cuda', False)
                     and torch.cuda.is_available())
    self.torch = torch

    from torch.autograd import Variable
    import torchvision
    import torchvision.transforms as transforms
    import torch.nn as nn

    try:
        import h5py
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError('Need to install h5py') from err
    self.h5py = h5py

    self.image_mode = opt['image_mode']
    self.image_size = opt['image_size']
    self.crop_size = opt['image_cropsize']

    if self.use_cuda:
        print('[ Using CUDA ]')
        # torch.cuda.set_device is documented as a no-op for negative
        # indices, so the -1 default keeps the current device.
        torch.cuda.set_device(opt.get('gpu', -1))

    cnn_type, layer_num = self.image_mode_switcher()

    # Look up the pretrained model constructor by name in torchvision.
    CNN = getattr(torchvision.models, cnn_type)

    # Keep only the first ``layer_num`` child modules (cuts off the
    # trailing layers).
    # NOTE(review): the network is left in its default (training) mode
    # here; confirm that the feature-extraction code calls ``.eval()`` if
    # inference-mode BatchNorm/Dropout behavior is expected.
    self.netCNN = nn.Sequential(
        *list(CNN(pretrained=True).children())[:layer_num])

    # ``transforms.Scale`` was renamed to ``Resize`` and later removed from
    # torchvision; prefer the new name and fall back for old releases.
    resize_cls = getattr(transforms, 'Resize', None)
    if resize_cls is None:
        resize_cls = transforms.Scale

    # Preprocessing pipeline: resize, center-crop, convert to tensor, then
    # normalize each channel with the given mean/std.
    self.transform = transforms.Compose([
        resize_cls(self.image_size),
        transforms.CenterCrop(self.crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Container for a single image.
    self.xs = torch.zeros(1, 3, self.crop_size, self.crop_size)

    if self.use_cuda:
        self.netCNN.cuda()
        self.xs = self.xs.cuda()

    # Wrap the container in a Variable (kept for older PyTorch releases).
    self.xs = Variable(self.xs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号