gridgen.py 文件源码

python
阅读 56 收藏 0 点赞 0 评论 0

项目:intel-cervical-cancer 作者: wangg12 项目源码 文件源码
def forward(self, input):
        if not self.aux_loss:
            return self.f(input)
        else:
            identity = torch.from_numpy(np.array([[1,0,0], [0,1,0]], dtype=np.float32))
            batch_identity = torch.zeros([input.size(0), 2,3])
            for i in range(input.size(0)):
                batch_identity[i] = identity

            if input.is_cuda:
                batch_identity = Variable(batch_identity.cuda())
            else:
                batch_identity = Variable(batch_identity)

            loss = torch.mul(input - batch_identity, input - batch_identity)
            loss = torch.sum(loss,1)
            loss = torch.sum(loss,2)

            return self.f(input), loss.view(-1,1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号