pytorch.py source code

python

Project: foolbox  Author: bethgelab
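def _loss_fn(self, image, label):
    import numpy as np

    # lazy import: torch is only needed when this backend is used
    import torch
    import torch.nn as nn

    # Apply the wrapper's input preprocessing to the single (unbatched) image
    image = self._process_input(image)

    # nn.CrossEntropyLoss expects int64 class indices; np.array([label]) can
    # be int32 on some platforms (e.g. Windows with NumPy 1.x), so force it
    target = torch.from_numpy(np.array([label], dtype=np.int64))
    # Add a leading batch dimension of size 1
    images = torch.from_numpy(image[None])
    if self.cuda:  # pragma: no cover
        target = target.cuda()
        images = images.cuda()

    # Variable(..., volatile=True) has had no effect since PyTorch 0.4;
    # torch.no_grad() is its replacement for gradient-free forward passes
    with torch.no_grad():
        predictions = self._model(images)
        loss = nn.CrossEntropyLoss()(predictions, target)

    # .cpu() is a no-op for CPU tensors; return the scalar loss as a NumPy value
    return loss.cpu().numpy()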
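
For a single image, `nn.CrossEntropyLoss` combines log-softmax and negative log-likelihood, and its default mean reduction over a batch of one changes nothing, so the value returned above is the negative log-softmax probability of the true label. The following is a minimal pure-Python sketch of that quantity, not foolbox or PyTorch code. The helper name `cross_entropy_single` is hypothetical, and it assumes `logits` are the raw (unnormalized) model outputs, which is what `nn.CrossEntropyLoss` expects.

```python
import math
from typing import Sequence


def cross_entropy_single(logits: Sequence[float], label: int) -> float:
    """Hypothetical helper: cross-entropy of one logit vector vs. a class index.

    Equivalent to -log(softmax(logits)[label]), computed as
    logsumexp(logits) - logits[label] for numerical stability.
    """
    if not 0 <= label < len(logits):
        raise IndexError("label out of range")
    # Subtract the largest logit before exponentiating so exp() cannot overflow
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


# Two equal logits: softmax is [0.5, 0.5], so the loss is log(2)
print(cross_entropy_single([0.0, 0.0], 0))        # ≈ 0.6931
# Logits this large would overflow a naive exp(); the shifted form stays stable
print(cross_entropy_single([1000.0, 1000.0], 0))  # ≈ 0.6931
```

Subtracting the maximum logit is the standard log-sum-exp trick: it leaves the result mathematically unchanged while keeping every exponent at or below zero.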