def forward(self, loc_preds, loc_targets, conf_preds, conf_targets):
    '''Compute the multibox loss for a batch of box predictions.

    The localization term is a summed Smooth-L1 loss over matched (positive)
    boxes only. The confidence term is a summed cross-entropy over the
    positive boxes plus the negatives chosen by ``self.hard_negative_mining``.
    Both terms are divided by the number of matched boxes in the batch.

    Args:
      loc_preds: (tensor) predicted locations, sized [batch_size, num_boxes, 4].
      loc_targets: (tensor) encoded target locations, sized [batch_size, num_boxes, 4].
      conf_preds: (tensor) predicted class confidences, sized [batch_size, num_boxes, num_classes].
      conf_targets: (tensor) encoded target classes, sized [batch_size, num_boxes];
        values > 0 mark boxes matched to an object.

    Returns:
      (loc_loss, conf_loss): two tensors,
        loc_loss  = SmoothL1Loss(pos_loc_preds, pos_loc_targets) / #matched
        conf_loss = CrossEntropyLoss(pos+neg conf_preds, targets) / #matched
      If no box in the batch is matched, both are zero tensors of shape [1]
      that carry no gradient.
    '''
    # num_boxes is taken from the input instead of assuming the 8732 anchors
    # of SSD300, so other anchor layouts work as well.
    batch_size, num_boxes, _ = loc_preds.size()

    pos = conf_targets > 0  # [N,num_boxes], set where the box matched an object.
    num_matched_boxes = pos.data.long().sum()
    if num_matched_boxes == 0:
        # Nothing to normalize by. Allocate the zeros from loc_preds so they
        # share its dtype and device (torch.Tensor([0]) is always CPU float).
        zero = loc_preds.data.new(1).zero_()
        return Variable(zero), Variable(zero.clone())

    ################################################################
    # loc_loss = SmoothL1Loss(pos_loc_preds, pos_loc_targets)
    ################################################################
    loc_mask = pos.unsqueeze(2).expand_as(loc_preds)        # [N,num_boxes,4]
    pos_loc_preds = loc_preds[loc_mask].view(-1, 4)         # [#pos,4]
    pos_loc_targets = loc_targets[loc_mask].view(-1, 4)     # [#pos,4]
    # size_average=False is the legacy spelling of reduction='sum'; kept for
    # compatibility with the PyTorch version this file targets.
    loc_loss = F.smooth_l1_loss(pos_loc_preds, pos_loc_targets, size_average=False)

    ################################################################
    # conf_loss = CrossEntropyLoss(pos_conf_preds, pos_conf_targets)
    #           + CrossEntropyLoss(neg_conf_preds, neg_conf_targets)
    ################################################################
    # self.cross_entropy_loss is defined elsewhere on the class; its result is
    # reshaped to one value per box below, so it must yield N*num_boxes elements.
    per_box_loss = self.cross_entropy_loss(
        conf_preds.view(-1, self.num_classes),
        conf_targets.view(-1))                              # [N*num_boxes,]
    per_box_loss = per_box_loss.view(batch_size, num_boxes)  # [N,num_boxes]

    # The mining helper is defined elsewhere on the class; its result is
    # combined element-wise with `pos`, so it must have the same shape.
    neg = self.hard_negative_mining(per_box_loss, pos)      # [N,num_boxes]

    # Boxes that contribute to the confidence loss: positives or mined negatives.
    pos_and_neg = (pos + neg).gt(0)                         # [N,num_boxes]
    conf_mask = pos_and_neg.unsqueeze(2).expand_as(conf_preds)  # [N,num_boxes,num_classes]
    preds = conf_preds[conf_mask].view(-1, self.num_classes)    # [#pos+#neg,num_classes]
    targets = conf_targets[pos_and_neg]                         # [#pos+#neg,]
    conf_loss = F.cross_entropy(preds, targets, size_average=False)

    # Normalize both terms by the number of matched boxes.
    loc_loss = loc_loss / num_matched_boxes
    conf_loss = conf_loss / num_matched_boxes
    return loc_loss, conf_loss
multibox_loss.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录