def forward(self, y, feat):
    # In older PyTorch releases torch.histc only ran on the CPU, hence the .cpu() call below.
    # Count how many samples of each class appear in the mini-batch (see Equation 4 in the paper).
    if self.use_cuda:
        hist = Variable(torch.histc(y.cpu().data.float(), bins=self.num_classes, min=0, max=self.num_classes) + 1).cuda()
    else:
        hist = Variable(torch.histc(y.data.float(), bins=self.num_classes, min=0, max=self.num_classes) + 1)
    centers_count = hist.index_select(0, y.long())  # per-sample count of its own class
    # Squeeze the feature tensor to shape (batch_size, feat_dim)
    batch_size = feat.size()[0]
    feat = feat.view(batch_size, 1, 1, -1).squeeze()
    # Check that the feature dimension matches the dimension of the centers
    if feat.size()[1] != self.feat_dim:
        raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}".format(self.feat_dim, feat.size()[1]))
    centers_pred = self.centers.index_select(0, y.long())  # pick each sample's class center
    diff = feat - centers_pred
    loss = self.loss_weight * 1 / 2.0 * (diff.pow(2).sum(1) / centers_count).sum()
    return loss
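For reference, what this forward pass computes (read directly off the code; relating it to the paper's numbered equations is my interpretation) is

$$L = \lambda \cdot \frac{1}{2} \sum_{i=1}^{m} \frac{\lVert x_i - c_{y_i} \rVert_2^2}{n_{y_i} + 1}$$

where $\lambda$ is loss_weight, $c_{y_i}$ is the learnable center of sample $i$'s class, and $n_{y_i}$ is the number of samples of that class in the mini-batch (the +1 comes from the "hist + 1" above).

Below is a minimal, self-contained sketch of how such a module might be assembled and used. The class name CenterLossSketch, the __init__ signature, and the random initialization of self.centers are assumptions for illustration (the post only shows forward), and it uses the modern PyTorch API without Variable, so it is not a drop-in copy of the snippet above.

import torch
import torch.nn as nn

class CenterLossSketch(nn.Module):
    def __init__(self, num_classes, feat_dim, loss_weight=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.loss_weight = loss_weight
        # one learnable center per class, shape (num_classes, feat_dim)
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, y, feat):
        # per-class counts in the batch, +1 as in the code above
        hist = torch.histc(y.float(), bins=self.num_classes, min=0, max=self.num_classes) + 1
        centers_count = hist.index_select(0, y.long())      # count for each sample's class
        centers_pred = self.centers.index_select(0, y.long())  # center of each sample's class
        diff = feat - centers_pred
        return self.loss_weight * 0.5 * (diff.pow(2).sum(1) / centers_count).sum()

# usage: 32 two-dimensional embeddings, 10 classes
criterion = CenterLossSketch(num_classes=10, feat_dim=2)
features = torch.randn(32, 2)
labels = torch.randint(0, 10, (32,))
loss = criterion(labels, features)
loss.backward()  # gradients flow into criterion.centers (and the network producing features)

In practice this term is added to the usual softmax cross-entropy loss, and criterion.centers is updated by the optimizer together with the network parameters.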