def hard_negative_mining(self, conf_loss, pos, neg_pos_ratio=3):
    """Select the hardest negative boxes for each image (hard negative mining).

    For every image, the non-positive boxes are ranked by confidence loss
    (highest first). The top ``neg_pos_ratio * num_pos`` of them are kept,
    capped at the number of negatives the image actually has.

    Args:
        conf_loss: (tensor) per-box cross-entropy loss between conf_preds and
            conf_targets, sized [N*num_boxes,] (e.g. num_boxes = 8732), or
            already [N, num_boxes]. It is NOT modified.
        pos: (tensor) positive (matched) box mask, sized [N, num_boxes].
            Bool or 0/1 integer masks are accepted.
        neg_pos_ratio: (number) negatives to keep per positive box.
            Defaults to 3.

    Returns:
        (tensor) bool mask of the selected negative boxes, sized
        [N, num_boxes]. It never overlaps ``pos``.
    """
    batch_size, num_boxes = pos.size()
    pos_mask = pos.bool()

    # Mining is a pure selection step. Work on a detached copy so the
    # caller's loss tensor is never mutated and autograd never sees this.
    # Reshape BEFORE masking: indexing a 1-D tensor with a 2-D mask fails.
    loss = conf_loss.detach().reshape(batch_size, num_boxes)
    # Positives get -inf (not 0) so they always sort after every negative,
    # including negatives whose loss is exactly 0.
    loss = loss.masked_fill(pos_mask, float('-inf'))

    # Double argsort: rank[i, j] is the position of box j in row i when the
    # row is sorted by loss in descending order (0 = hardest negative).
    _, idx = loss.sort(1, descending=True)
    _, rank = idx.sort(1)

    num_pos = pos_mask.long().sum(1)  # [N]
    # An image cannot contribute more negatives than it actually has.
    num_neg = torch.min(neg_pos_ratio * num_pos, num_boxes - num_pos)  # [N]

    neg = rank < num_neg.unsqueeze(1)  # [N, num_boxes]
    # Explicitly guarantee the negative mask is disjoint from the positives.
    return neg & ~pos_mask
multibox_loss.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录