python类histc()的实例源码

CenterLoss.py 文件源码 项目:MNIST_center_loss_pytorch 作者: jxgu1016 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def forward(self, y, feat):
        # torch.histc can only be implemented on CPU
        # To calculate the total number of every class in one mini-batch. See Equation 4 in the paper
        if self.use_cuda:
            hist = Variable(torch.histc(y.cpu().data.float(),bins=self.num_classes,min=0,max=self.num_classes) + 1).cuda()
        else:
            hist = Variable(torch.histc(y.data.float(),bins=self.num_classes,min=0,max=self.num_classes) + 1)

        centers_count = hist.index_select(0,y.long())


        # To squeeze the Tenosr
        batch_size = feat.size()[0]
        feat = feat.view(batch_size, 1, 1, -1).squeeze()
        # To check the dim of centers and features
        if feat.size()[1] != self.feat_dim:
            raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}".format(self.feat_dim,feat.size()[1]))

        centers_pred = self.centers.index_select(0, y.long())
        diff = feat - centers_pred
        loss = self.loss_weight * 1 / 2.0 * (diff.pow(2).sum(1) / centers_count).sum()
        return loss
test_torch.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def test_histc(self):
        x = torch.Tensor((2, 4, 2, 2, 5, 4))
        y = torch.histc(x, 5, 1, 5) # nbins,  min,  max
        z = torch.Tensor((0, 3, 0, 2, 1))
        self.assertEqual(y, z)
test_torch.py 文件源码 项目:pytorch 作者: tylergenter 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def test_histc(self):
        x = torch.Tensor((2, 4, 2, 2, 5, 4))
        y = torch.histc(x, 5, 1, 5)  # nbins,  min,  max
        z = torch.Tensor((0, 3, 0, 2, 1))
        self.assertEqual(y, z)
test_torch.py 文件源码 项目:pytorch-coriander 作者: hughperkins 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def test_histc(self):
        x = torch.Tensor((2, 4, 2, 2, 5, 4))
        y = torch.histc(x, 5, 1, 5)  # nbins,  min,  max
        z = torch.Tensor((0, 3, 0, 2, 1))
        self.assertEqual(y, z)
test_torch.py 文件源码 项目:pytorch 作者: ezyang 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def test_histc(self):
        x = torch.Tensor((2, 4, 2, 2, 5, 4))
        y = torch.histc(x, 5, 1, 5)  # nbins,  min,  max
        z = torch.Tensor((0, 3, 0, 2, 1))
        self.assertEqual(y, z)
test_torch.py 文件源码 项目:pytorch 作者: pytorch 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def test_histc(self):
        x = torch.Tensor((2, 4, 2, 2, 5, 4))
        y = torch.histc(x, 5, 1, 5)  # nbins,  min,  max
        z = torch.Tensor((0, 3, 0, 2, 1))
        self.assertEqual(y, z)


问题


面经


文章

微信
公众号

扫码关注公众号