pytorch.py 文件源码

python
阅读 41 收藏 0 点赞 0 评论 0

项目:foolbox 作者: bethgelab 项目源码 文件源码
def batch_predictions(self, images):
        # lazy import
        import torch
        from torch.autograd import Variable

        images = self._process_input(images)
        n = len(images)
        images = torch.from_numpy(images)
        if self.cuda:  # pragma: no cover
            images = images.cuda()
        images = Variable(images, volatile=True)
        predictions = self._model(images)
        predictions = predictions.data
        if self.cuda:  # pragma: no cover
            predictions = predictions.cpu()
        predictions = predictions.numpy()
        assert predictions.ndim == 2
        assert predictions.shape == (n, self.num_classes())
        return predictions
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号